From 444382ea3a54b1b7f8066b155a2da64cc03775af Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Wed, 17 Jun 2015 22:06:44 -0700 Subject: [PATCH] remove expensive api from InternalRow --- .../java/org/apache/spark/sql/BaseRow.java | 106 +--------- .../main/scala/org/apache/spark/sql/Row.scala | 99 ++-------- .../spark/sql/catalyst/InternalRow.scala | 184 ++++++++++++++++-- .../sql/catalyst/expressions/Projection.scala | 36 ---- .../expressions/SpecificMutableRow.scala | 5 +- .../expressions/UnsafeRowConverter.scala | 2 - .../sql/catalyst/expressions/generators.scala | 8 +- .../spark/sql/catalyst/expressions/rows.scala | 12 +- .../org/apache/spark/sql/SQLContext.scala | 3 +- .../spark/sql/columnar/ColumnType.scala | 72 +++---- .../columnar/InMemoryColumnarTableScan.scala | 3 +- .../sql/execution/SparkSqlSerializer2.scala | 5 +- .../spark/sql/execution/pythonUdfs.scala | 4 +- .../sql/execution/stat/StatFunctions.scala | 3 +- .../org/apache/spark/sql/jdbc/JDBCRDD.scala | 2 +- .../org/apache/spark/sql/jdbc/jdbc.scala | 2 +- .../apache/spark/sql/parquet/newParquet.scala | 4 +- .../apache/spark/sql/sources/commands.scala | 39 ++-- .../org/apache/spark/sql/sources/ddl.scala | 12 +- .../apache/spark/sql/sources/interfaces.scala | 10 + .../spark/sql/sources/DDLTestSuite.scala | 3 +- .../spark/sql/sources/TableScanSuite.scala | 6 +- .../apache/spark/sql/hive/TableReader.scala | 3 +- .../hive/execution/CreateTableAsSelect.scala | 14 +- .../execution/DescribeHiveTableCommand.scala | 8 +- .../hive/execution/HiveNativeCommand.scala | 8 +- .../hive/execution/InsertIntoHiveTable.scala | 4 +- .../spark/sql/hive/execution/commands.scala | 37 ++-- .../spark/sql/hive/hiveWriterContainers.scala | 15 +- .../spark/sql/hive/orc/OrcRelation.scala | 14 +- 30 files changed, 345 insertions(+), 378 deletions(-) diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/BaseRow.java b/sql/catalyst/src/main/java/org/apache/spark/sql/BaseRow.java index 611e02d8fb666..69d0848e2da83 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/BaseRow.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/BaseRow.java @@ -17,11 +17,6 @@ package org.apache.spark.sql; -import java.math.BigDecimal; -import java.sql.Date; -import java.sql.Timestamp; -import java.util.List; - import scala.collection.Seq; import scala.collection.mutable.ArraySeq; @@ -36,17 +31,6 @@ final public int length() { return size(); } - @Override - public boolean anyNull() { - final int n = size(); - for (int i=0; i < n; i++) { - if (isNullAt(i)) { - return true; - } - } - return false; - } - @Override public StructType schema() { throw new UnsupportedOperationException(); } @@ -90,78 +74,13 @@ public boolean getBoolean(int i) { throw new UnsupportedOperationException(); } - @Override - public String getString(int i) { - throw new UnsupportedOperationException(); - } - - @Override - public BigDecimal getDecimal(int i) { - throw new UnsupportedOperationException(); - } - - @Override - public Date getDate(int i) { - throw new UnsupportedOperationException(); - } - - @Override - public Timestamp getTimestamp(int i) { - throw new UnsupportedOperationException(); - } - - @Override - public Seq getSeq(int i) { - throw new UnsupportedOperationException(); - } - - @Override - public List getList(int i) { - throw new UnsupportedOperationException(); - } - - @Override - public scala.collection.Map getMap(int i) { - throw new UnsupportedOperationException(); - } - - @Override - public scala.collection.immutable.Map getValuesMap(Seq fieldNames) { - throw new 
UnsupportedOperationException(); - } - - @Override - public java.util.Map getJavaMap(int i) { - throw new UnsupportedOperationException(); - } - - @Override - public Row getStruct(int i) { - throw new UnsupportedOperationException(); - } - - @Override - public T getAs(int i) { - throw new UnsupportedOperationException(); - } - - @Override - public T getAs(String fieldName) { - throw new UnsupportedOperationException(); - } - - @Override - public int fieldIndex(String name) { - throw new UnsupportedOperationException(); - } - /** * A generic version of Row.equals(Row), which is used for tests. */ @Override public boolean equals(Object other) { - if (other instanceof Row) { - Row row = (Row) other; + if (other instanceof InternalRow) { + InternalRow row = (InternalRow) other; int n = size(); if (n != row.size()) { return false; @@ -186,7 +105,6 @@ public InternalRow copy() { return new GenericRow(arr); } - @Override public Seq toSeq() { final int n = size(); final ArraySeq values = new ArraySeq(n); @@ -195,24 +113,4 @@ public Seq toSeq() { } return values; } - - @Override - public String toString() { - return mkString("[", ",", "]"); - } - - @Override - public String mkString() { - return toSeq().mkString(); - } - - @Override - public String mkString(String sep) { - return toSeq().mkString(sep); - } - - @Override - public String mkString(String start, String sep, String end) { - return toSeq().mkString(start, sep, end); - } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala index 8aaf5d7d89154..679eb5578d81b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala @@ -17,8 +17,7 @@ package org.apache.spark.sql -import scala.util.hashing.MurmurHash3 - +import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.GenericRow import org.apache.spark.sql.types.StructType @@ -118,17 +117,17 @@ object Row { * * @group row */ -trait Row extends Serializable { +trait Row extends InternalRow { /** Number of elements in the Row. */ - def size: Int = length + override def size: Int = length /** Number of elements in the Row. */ - def length: Int + override def length: Int /** * Schema for the row. */ - def schema: StructType = null + override def schema: StructType = null /** * Returns the value at position i. If the value is null, null is returned. The following @@ -153,7 +152,7 @@ trait Row extends Serializable { * StructType -> org.apache.spark.sql.Row * }}} */ - def apply(i: Int): Any + override def apply(i: Int): Any /** * Returns the value at position i. If the value is null, null is returned. The following @@ -178,10 +177,10 @@ trait Row extends Serializable { * StructType -> org.apache.spark.sql.Row * }}} */ - def get(i: Int): Any = apply(i) + override def get(i: Int): Any = apply(i) /** Checks whether the value at position i is null. */ - def isNullAt(i: Int): Boolean + override def isNullAt(i: Int): Boolean /** * Returns the value at position i as a primitive boolean. @@ -189,7 +188,7 @@ trait Row extends Serializable { * @throws ClassCastException when data type does not match. * @throws NullPointerException when value is null. */ - def getBoolean(i: Int): Boolean + override def getBoolean(i: Int): Boolean /** * Returns the value at position i as a primitive byte. @@ -197,7 +196,7 @@ trait Row extends Serializable { * @throws ClassCastException when data type does not match. 
* @throws NullPointerException when value is null. */ - def getByte(i: Int): Byte + override def getByte(i: Int): Byte /** * Returns the value at position i as a primitive short. @@ -205,7 +204,7 @@ trait Row extends Serializable { * @throws ClassCastException when data type does not match. * @throws NullPointerException when value is null. */ - def getShort(i: Int): Short + override def getShort(i: Int): Short /** * Returns the value at position i as a primitive int. @@ -213,7 +212,7 @@ trait Row extends Serializable { * @throws ClassCastException when data type does not match. * @throws NullPointerException when value is null. */ - def getInt(i: Int): Int + override def getInt(i: Int): Int /** * Returns the value at position i as a primitive long. @@ -221,7 +220,7 @@ trait Row extends Serializable { * @throws ClassCastException when data type does not match. * @throws NullPointerException when value is null. */ - def getLong(i: Int): Long + override def getLong(i: Int): Long /** * Returns the value at position i as a primitive float. @@ -230,7 +229,7 @@ trait Row extends Serializable { * @throws ClassCastException when data type does not match. * @throws NullPointerException when value is null. */ - def getFloat(i: Int): Float + override def getFloat(i: Int): Float /** * Returns the value at position i as a primitive double. @@ -238,7 +237,7 @@ trait Row extends Serializable { * @throws ClassCastException when data type does not match. * @throws NullPointerException when value is null. */ - def getDouble(i: Int): Double + override def getDouble(i: Int): Double /** * Returns the value at position i as a String object. @@ -246,7 +245,7 @@ trait Row extends Serializable { * @throws ClassCastException when data type does not match. * @throws NullPointerException when value is null. */ - def getString(i: Int): String + override def getString(i: Int): String /** * Returns the value at position i of decimal type as java.math.BigDecimal. @@ -313,7 +312,7 @@ trait Row extends Serializable { * * @throws ClassCastException when data type does not match. */ - def getAs[T](i: Int): T = apply(i).asInstanceOf[T] + override def getAs[T](i: Int): T = apply(i).asInstanceOf[T] /** * Returns the value of a given fieldName. @@ -347,70 +346,8 @@ trait Row extends Serializable { }.toMap } - override def toString(): String = s"[${this.mkString(",")}]" - /** * Make a copy of the current [[Row]] object. */ - def copy(): Row - - /** Returns true if there are any NULL values in this row. */ - def anyNull: Boolean = { - val len = length - var i = 0 - while (i < len) { - if (isNullAt(i)) { return true } - i += 1 - } - false - } - - override def equals(that: Any): Boolean = that match { - case null => false - case that: Row => - if (this.length != that.length) { - return false - } - var i = 0 - val len = this.length - while (i < len) { - if (apply(i) != that.apply(i)) { - return false - } - i += 1 - } - true - case _ => false - } - - override def hashCode: Int = { - // Using Scala's Seq hash code implementation. - var n = 0 - var h = MurmurHash3.seqSeed - val len = length - while (n < len) { - h = MurmurHash3.mix(h, apply(n).##) - n += 1 - } - MurmurHash3.finalizeHash(h, n) - } - - /* ---------------------- utility methods for Scala ---------------------- */ - - /** - * Return a Scala Seq representing the row. ELements are placed in the same order in the Seq. - */ - def toSeq: Seq[Any] - - /** Displays all elements of this sequence in a string (without a separator). 
*/ - def mkString: String = toSeq.mkString - - /** Displays all elements of this sequence in a string using a separator string. */ - def mkString(sep: String): String = toSeq.mkString(sep) - - /** - * Displays all elements of this traversable or iterator in a string using - * start, end, and separator strings. - */ - def mkString(start: String, sep: String, end: String): String = toSeq.mkString(start, sep, end) + override def copy(): Row } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala index e3c2cc243310b..9b4f8aafa1092 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala @@ -19,39 +19,193 @@ package org.apache.spark.sql.catalyst import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.expressions.GenericRow +import org.apache.spark.sql.types.StructType +import org.apache.spark.unsafe.types.UTF8String /** * An abstract class for row used internal in Spark SQL, which only contain the columns as * internal types. + * + * The following is a mapping between Spark SQL types and types of objects in row: + * + * BooleanType -> java.lang.Boolean + * ByteType -> java.lang.Byte + * ShortType -> java.lang.Short + * IntegerType -> java.lang.Integer + * FloatType -> java.lang.Float + * DoubleType -> java.lang.Double + * StringType -> UTF8String + * DecimalType -> org.apache.spark.sql.types.Decimal + * + * DateType -> java.lang.Int + * TimestampType -> java.lang.Long + * + * BinaryType -> Array[Byte] + * ArrayType -> scala.collection.Seq + * MapType -> scala.collection.Map + * StructType -> InternalRow */ -abstract class InternalRow extends Row { - // A default implementation to change the return type - override def copy(): InternalRow = {this} -} +abstract class InternalRow extends Serializable { + /** Number of elements in the Row. */ + def size: Int = length -object InternalRow { - def unapplySeq(row: InternalRow): Some[Seq[Any]] = Some(row.toSeq) + /** Number of elements in the Row. */ + def length: Int /** - * This method can be used to construct a [[Row]] with the given values. + * Schema for the row. */ - def apply(values: Any*): InternalRow = new GenericRow(values.toArray) + def schema: StructType = null /** - * This method can be used to construct a [[Row]] from a [[Seq]] of values. + * Returns the value at position i. If the value is null, null is returned. */ - def fromSeq(values: Seq[Any]): InternalRow = new GenericRow(values.toArray) + def apply(i: Int): Any + + /** + * Returns the value at position i. If the value is null, null is returned. + */ + def get(i: Int): Any = apply(i) + + /** Checks whether the value at position i is null. */ + def isNullAt(i: Int): Boolean + + /** + * Returns the value at position i as a primitive boolean. + * + * @throws ClassCastException when data type does not match. + * @throws NullPointerException when value is null. + */ + def getBoolean(i: Int): Boolean + + /** + * Returns the value at position i as a primitive byte. + * + * @throws ClassCastException when data type does not match. + * @throws NullPointerException when value is null. + */ + def getByte(i: Int): Byte + + /** + * Returns the value at position i as a primitive short. + * + * @throws ClassCastException when data type does not match. + * @throws NullPointerException when value is null. 
+ */ + def getShort(i: Int): Short + + /** + * Returns the value at position i as a primitive int. + * + * @throws ClassCastException when data type does not match. + * @throws NullPointerException when value is null. + */ + def getInt(i: Int): Int + + /** + * Returns the value at position i as a primitive long. + * + * @throws ClassCastException when data type does not match. + * @throws NullPointerException when value is null. + */ + def getLong(i: Int): Long - def fromTuple(tuple: Product): InternalRow = fromSeq(tuple.productIterator.toSeq) + /** + * Returns the value at position i as a primitive float. + * Throws an exception if the type mismatches or if the value is null. + * + * @throws ClassCastException when data type does not match. + * @throws NullPointerException when value is null. + */ + def getFloat(i: Int): Float /** - * Merge multiple rows into a single row, one after another. + * Returns the value at position i as a primitive double. + * + * @throws ClassCastException when data type does not match. + * @throws NullPointerException when value is null. */ - def merge(rows: InternalRow*): InternalRow = { - // TODO: Improve the performance of this if used in performance critical part. - new GenericRow(rows.flatMap(_.toSeq).toArray) + def getDouble(i: Int): Double + + /** + * Returns the value at position i as a String object. + * + * @throws ClassCastException when data type does not match. + * @throws NullPointerException when value is null. + */ + def getString(i: Int): String = getAs[UTF8String](i).toString + + /** + * Returns the value at position i. + * + * @throws ClassCastException when data type does not match. + */ + def getAs[T](i: Int): T = apply(i).asInstanceOf[T] + + override def toString: String = s"[${this.mkString(",")}]" + + /** Returns true if there are any NULL values in this row. */ + def anyNull: Boolean = { + val len = length + var i = 0 + while (i < len) { + if (isNullAt(i)) { return true } + i += 1 + } + false + } + + override def equals(that: Any): Boolean = that match { + case null => false + case that: InternalRow => + if (this.length != that.length) { + return false + } + var i = 0 + val len = this.length + while (i < len) { + if (apply(i) != that.apply(i)) { + return false + } + i += 1 + } + true + case _ => false } + def copy(): InternalRow + + /* ---------------------- utility methods for Scala ---------------------- */ + + /** + * Return a Scala Seq representing the row. ELements are placed in the same order in the Seq. + */ + def toSeq: Seq[Any] + + /** Displays all elements of this sequence in a string (without a separator). */ + def mkString: String = toSeq.mkString + + /** Displays all elements of this sequence in a string using a separator string. */ + def mkString(sep: String): String = toSeq.mkString(sep) + + /** + * Displays all elements of this traversable or iterator in a string using + * start, end, and separator strings. + */ + def mkString(start: String, sep: String, end: String): String = toSeq.mkString(start, sep, end) +} + +object InternalRow { + /** + * This method can be used to construct a [[Row]] with the given values. + */ + def apply(values: Any*): InternalRow = new GenericRow(values.toArray) + + /** + * This method can be used to construct a [[Row]] from a [[Seq]] of values. + */ + def fromSeq(values: Seq[Any]): InternalRow = new GenericRow(values.toArray) + /** Returns an empty row. 
*/ val empty = apply() } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala index d5967438ccb5a..94af021c943d3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala @@ -135,12 +135,6 @@ class JoinedRow extends InternalRow { override def getFloat(i: Int): Float = if (i < row1.length) row1.getFloat(i) else row2.getFloat(i - row1.length) - override def getString(i: Int): String = - if (i < row1.length) row1.getString(i) else row2.getString(i - row1.length) - - override def getAs[T](i: Int): T = - if (i < row1.length) row1.getAs[T](i) else row2.getAs[T](i - row1.length) - override def copy(): InternalRow = { val totalSize = row1.length + row2.length val copiedValues = new Array[Any](totalSize) @@ -235,12 +229,6 @@ class JoinedRow2 extends InternalRow { override def getFloat(i: Int): Float = if (i < row1.length) row1.getFloat(i) else row2.getFloat(i - row1.length) - override def getString(i: Int): String = - if (i < row1.length) row1.getString(i) else row2.getString(i - row1.length) - - override def getAs[T](i: Int): T = - if (i < row1.length) row1.getAs[T](i) else row2.getAs[T](i - row1.length) - override def copy(): InternalRow = { val totalSize = row1.length + row2.length val copiedValues = new Array[Any](totalSize) @@ -329,12 +317,6 @@ class JoinedRow3 extends InternalRow { override def getFloat(i: Int): Float = if (i < row1.length) row1.getFloat(i) else row2.getFloat(i - row1.length) - override def getString(i: Int): String = - if (i < row1.length) row1.getString(i) else row2.getString(i - row1.length) - - override def getAs[T](i: Int): T = - if (i < row1.length) row1.getAs[T](i) else row2.getAs[T](i - row1.length) - override def copy(): InternalRow = { val totalSize = row1.length + row2.length val copiedValues = new Array[Any](totalSize) @@ -423,12 +405,6 @@ class JoinedRow4 extends InternalRow { override def getFloat(i: Int): Float = if (i < row1.length) row1.getFloat(i) else row2.getFloat(i - row1.length) - override def getString(i: Int): String = - if (i < row1.length) row1.getString(i) else row2.getString(i - row1.length) - - override def getAs[T](i: Int): T = - if (i < row1.length) row1.getAs[T](i) else row2.getAs[T](i - row1.length) - override def copy(): InternalRow = { val totalSize = row1.length + row2.length val copiedValues = new Array[Any](totalSize) @@ -517,12 +493,6 @@ class JoinedRow5 extends InternalRow { override def getFloat(i: Int): Float = if (i < row1.length) row1.getFloat(i) else row2.getFloat(i - row1.length) - override def getString(i: Int): String = - if (i < row1.length) row1.getString(i) else row2.getString(i - row1.length) - - override def getAs[T](i: Int): T = - if (i < row1.length) row1.getAs[T](i) else row2.getAs[T](i - row1.length) - override def copy(): InternalRow = { val totalSize = row1.length + row2.length val copiedValues = new Array[Any](totalSize) @@ -611,12 +581,6 @@ class JoinedRow6 extends InternalRow { override def getFloat(i: Int): Float = if (i < row1.length) row1.getFloat(i) else row2.getFloat(i - row1.length) - override def getString(i: Int): String = - if (i < row1.length) row1.getString(i) else row2.getString(i - row1.length) - - override def getAs[T](i: Int): T = - if (i < row1.length) row1.getAs[T](i) else row2.getAs[T](i - row1.length) - override def copy(): 
InternalRow = { val totalSize = row1.length + row2.length val copiedValues = new Array[Any](totalSize) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala index 05aab34559985..5171727256799 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.catalyst.expressions +import org.apache.spark.sql.Row import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String @@ -191,7 +192,7 @@ final class MutableAny extends MutableValue { * based on the dataTypes of each column. The intent is to decrease garbage when modifying the * values of primitive columns. */ -final class SpecificMutableRow(val values: Array[MutableValue]) extends MutableRow { +final class SpecificMutableRow(val values: Array[MutableValue]) extends Row with MutableRow { def this(dataTypes: Seq[DataType]) = this( @@ -222,7 +223,7 @@ final class SpecificMutableRow(val values: Array[MutableValue]) extends MutableR override def isNullAt(i: Int): Boolean = values(i).isNull - override def copy(): InternalRow = { + override def copy(): Row = { val newValues = new Array[Any](values.length) var i = 0 while (i < values.length) { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverter.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverter.scala index 72f740ecaead3..439cb3a70a97e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverter.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverter.scala @@ -17,8 +17,6 @@ package org.apache.spark.sql.catalyst.expressions -import org.apache.spark.sql.catalyst.util.DateUtils -import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.types._ import org.apache.spark.unsafe.PlatformDependent import org.apache.spark.unsafe.array.ByteArrayMethods diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala index f30cb42d12b83..48bdfa66d9ff8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.expressions import scala.collection.Map -import org.apache.spark.sql.catalyst +import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.{CatalystTypeConverters, trees} import org.apache.spark.sql.types._ @@ -68,19 +68,19 @@ abstract class Generator extends Expression { */ case class UserDefinedGenerator( elementTypes: Seq[(DataType, Boolean)], - function: InternalRow => TraversableOnce[InternalRow], + function: Row => TraversableOnce[InternalRow], children: Seq[Expression]) extends Generator { @transient private[this] var inputRow: InterpretedProjection = _ - @transient private[this] var convertToScala: (InternalRow) => InternalRow = _ + @transient private[this] var convertToScala: InternalRow => Row = _ private def initializeConverters(): Unit = { inputRow = new InterpretedProjection(children) convertToScala = { val inputSchema = 
StructType(children.map(e => StructField(e.simpleString, e.dataType, true))) CatalystTypeConverters.createToScalaConverter(inputSchema) - }.asInstanceOf[(InternalRow => InternalRow)] + }.asInstanceOf[InternalRow => Row] } override def eval(input: InternalRow): TraversableOnce[InternalRow] = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala index 1098962ddc018..b0938f42d4025 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.catalyst.expressions +import org.apache.spark.sql.Row import org.apache.spark.sql.types.{DataType, StructType, AtomicType} import org.apache.spark.unsafe.types.UTF8String @@ -36,6 +37,7 @@ trait MutableRow extends InternalRow { def setShort(ordinal: Int, value: Short) def setByte(ordinal: Int, value: Byte) def setFloat(ordinal: Int, value: Float) + // only used for tests def setString(ordinal: Int, value: String) } @@ -54,7 +56,6 @@ object EmptyRow extends InternalRow { override def getBoolean(i: Int): Boolean = throw new UnsupportedOperationException override def getShort(i: Int): Short = throw new UnsupportedOperationException override def getByte(i: Int): Byte = throw new UnsupportedOperationException - override def getString(i: Int): String = throw new UnsupportedOperationException override def getAs[T](i: Int): T = throw new UnsupportedOperationException override def copy(): InternalRow = this } @@ -64,7 +65,7 @@ object EmptyRow extends InternalRow { * the array is not copied, and thus could technically be mutated after creation, this is not * allowed. */ -class GenericRow(protected[sql] val values: Array[Any]) extends InternalRow { +class GenericRow(protected[sql] val values: Array[Any]) extends Row { /** No-arg constructor for serialization. */ protected def this() = this(null) @@ -114,6 +115,7 @@ class GenericRow(protected[sql] val values: Array[Any]) extends InternalRow { } override def getString(i: Int): String = { + // the objects in values could be internal type or public types, so we need to check that. values(i) match { case null => null case s: String => s @@ -121,8 +123,6 @@ class GenericRow(protected[sql] val values: Array[Any]) extends InternalRow { } } - // TODO(davies): add getDate and getDecimal - // Custom hashCode function that matches the efficient code generated version. 
override def hashCode: Int = { var result: Int = 37 @@ -173,7 +173,7 @@ class GenericRow(protected[sql] val values: Array[Any]) extends InternalRow { case _ => false } - override def copy(): InternalRow = this + override def copy(): Row = this } class GenericRowWithSchema(values: Array[Any], override val schema: StructType) @@ -207,7 +207,7 @@ class GenericMutableRow(v: Array[Any]) extends GenericRow(v) with MutableRow { override def update(ordinal: Int, value: Any): Unit = { values(ordinal) = value } - override def copy(): InternalRow = new GenericRow(values.clone()) + override def copy(): Row = new GenericRow(values.clone()) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index 9d1f89d6d7bd8..88d7eeb55693f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -41,6 +41,7 @@ import org.apache.spark.sql.catalyst.{InternalRow, ParserDialect, _} import org.apache.spark.sql.execution.{Filter, _} import org.apache.spark.sql.sources._ import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.UTF8String import org.apache.spark.util.Utils /** @@ -389,7 +390,7 @@ class SQLContext(@transient val sparkContext: SparkContext) val rows = data.mapPartitions { iter => val row = new SpecificMutableRow(dataType :: Nil) iter.map { v => - row.setString(0, v) + row.update(0, UTF8String.fromString(v)) row: Row } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala index 8e21020917768..c52f77c55e6c4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala @@ -21,7 +21,7 @@ import java.nio.ByteBuffer import scala.reflect.runtime.universe.TypeTag -import org.apache.spark.sql.Row +import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.MutableRow import org.apache.spark.sql.execution.SparkSqlSerializer import org.apache.spark.sql.types._ @@ -63,7 +63,7 @@ private[sql] sealed abstract class ColumnType[T <: DataType, JvmType]( * Appends `row(ordinal)` of type T into the given ByteBuffer. Subclasses should override this * method to avoid boxing/unboxing costs whenever possible. */ - def append(row: Row, ordinal: Int, buffer: ByteBuffer): Unit = { + def append(row: InternalRow, ordinal: Int, buffer: ByteBuffer): Unit = { append(getField(row, ordinal), buffer) } @@ -71,13 +71,13 @@ private[sql] sealed abstract class ColumnType[T <: DataType, JvmType]( * Returns the size of the value `row(ordinal)`. This is used to calculate the size of variable * length types such as byte arrays and strings. */ - def actualSize(row: Row, ordinal: Int): Int = defaultSize + def actualSize(row: InternalRow, ordinal: Int): Int = defaultSize /** * Returns `row(ordinal)`. Subclasses should override this method to avoid boxing/unboxing costs * whenever possible. */ - def getField(row: Row, ordinal: Int): JvmType + def getField(row: InternalRow, ordinal: Int): JvmType /** * Sets `row(ordinal)` to `field`. Subclasses should override this method to avoid boxing/unboxing @@ -89,7 +89,7 @@ private[sql] sealed abstract class ColumnType[T <: DataType, JvmType]( * Copies `from(fromOrdinal)` to `to(toOrdinal)`. Subclasses should override this method to avoid * boxing/unboxing costs whenever possible. 
*/ - def copyField(from: Row, fromOrdinal: Int, to: MutableRow, toOrdinal: Int): Unit = { + def copyField(from: InternalRow, fromOrdinal: Int, to: MutableRow, toOrdinal: Int): Unit = { to(toOrdinal) = from(fromOrdinal) } @@ -118,7 +118,7 @@ private[sql] object INT extends NativeColumnType(IntegerType, 0, 4) { buffer.putInt(v) } - override def append(row: Row, ordinal: Int, buffer: ByteBuffer): Unit = { + override def append(row: InternalRow, ordinal: Int, buffer: ByteBuffer): Unit = { buffer.putInt(row.getInt(ordinal)) } @@ -134,9 +134,9 @@ private[sql] object INT extends NativeColumnType(IntegerType, 0, 4) { row.setInt(ordinal, value) } - override def getField(row: Row, ordinal: Int): Int = row.getInt(ordinal) + override def getField(row: InternalRow, ordinal: Int): Int = row.getInt(ordinal) - override def copyField(from: Row, fromOrdinal: Int, to: MutableRow, toOrdinal: Int): Unit = { + override def copyField(from: InternalRow, fromOrdinal: Int, to: MutableRow, toOrdinal: Int) { to.setInt(toOrdinal, from.getInt(fromOrdinal)) } } @@ -146,7 +146,7 @@ private[sql] object LONG extends NativeColumnType(LongType, 1, 8) { buffer.putLong(v) } - override def append(row: Row, ordinal: Int, buffer: ByteBuffer): Unit = { + override def append(row: InternalRow, ordinal: Int, buffer: ByteBuffer): Unit = { buffer.putLong(row.getLong(ordinal)) } @@ -162,9 +162,9 @@ private[sql] object LONG extends NativeColumnType(LongType, 1, 8) { row.setLong(ordinal, value) } - override def getField(row: Row, ordinal: Int): Long = row.getLong(ordinal) + override def getField(row: InternalRow, ordinal: Int): Long = row.getLong(ordinal) - override def copyField(from: Row, fromOrdinal: Int, to: MutableRow, toOrdinal: Int): Unit = { + override def copyField(from: InternalRow, fromOrdinal: Int, to: MutableRow, toOrdinal: Int) { to.setLong(toOrdinal, from.getLong(fromOrdinal)) } } @@ -174,7 +174,7 @@ private[sql] object FLOAT extends NativeColumnType(FloatType, 2, 4) { buffer.putFloat(v) } - override def append(row: Row, ordinal: Int, buffer: ByteBuffer): Unit = { + override def append(row: InternalRow, ordinal: Int, buffer: ByteBuffer): Unit = { buffer.putFloat(row.getFloat(ordinal)) } @@ -190,9 +190,9 @@ private[sql] object FLOAT extends NativeColumnType(FloatType, 2, 4) { row.setFloat(ordinal, value) } - override def getField(row: Row, ordinal: Int): Float = row.getFloat(ordinal) + override def getField(row: InternalRow, ordinal: Int): Float = row.getFloat(ordinal) - override def copyField(from: Row, fromOrdinal: Int, to: MutableRow, toOrdinal: Int): Unit = { + override def copyField(from: InternalRow, fromOrdinal: Int, to: MutableRow, toOrdinal: Int) { to.setFloat(toOrdinal, from.getFloat(fromOrdinal)) } } @@ -202,7 +202,7 @@ private[sql] object DOUBLE extends NativeColumnType(DoubleType, 3, 8) { buffer.putDouble(v) } - override def append(row: Row, ordinal: Int, buffer: ByteBuffer): Unit = { + override def append(row: InternalRow, ordinal: Int, buffer: ByteBuffer): Unit = { buffer.putDouble(row.getDouble(ordinal)) } @@ -218,9 +218,9 @@ private[sql] object DOUBLE extends NativeColumnType(DoubleType, 3, 8) { row.setDouble(ordinal, value) } - override def getField(row: Row, ordinal: Int): Double = row.getDouble(ordinal) + override def getField(row: InternalRow, ordinal: Int): Double = row.getDouble(ordinal) - override def copyField(from: Row, fromOrdinal: Int, to: MutableRow, toOrdinal: Int): Unit = { + override def copyField(from: InternalRow, fromOrdinal: Int, to: MutableRow, toOrdinal: Int) { to.setDouble(toOrdinal, 
from.getDouble(fromOrdinal)) } } @@ -230,7 +230,7 @@ private[sql] object BOOLEAN extends NativeColumnType(BooleanType, 4, 1) { buffer.put(if (v) 1: Byte else 0: Byte) } - override def append(row: Row, ordinal: Int, buffer: ByteBuffer): Unit = { + override def append(row: InternalRow, ordinal: Int, buffer: ByteBuffer): Unit = { buffer.put(if (row.getBoolean(ordinal)) 1: Byte else 0: Byte) } @@ -244,9 +244,9 @@ private[sql] object BOOLEAN extends NativeColumnType(BooleanType, 4, 1) { row.setBoolean(ordinal, value) } - override def getField(row: Row, ordinal: Int): Boolean = row.getBoolean(ordinal) + override def getField(row: InternalRow, ordinal: Int): Boolean = row.getBoolean(ordinal) - override def copyField(from: Row, fromOrdinal: Int, to: MutableRow, toOrdinal: Int): Unit = { + override def copyField(from: InternalRow, fromOrdinal: Int, to: MutableRow, toOrdinal: Int) { to.setBoolean(toOrdinal, from.getBoolean(fromOrdinal)) } } @@ -256,7 +256,7 @@ private[sql] object BYTE extends NativeColumnType(ByteType, 5, 1) { buffer.put(v) } - override def append(row: Row, ordinal: Int, buffer: ByteBuffer): Unit = { + override def append(row: InternalRow, ordinal: Int, buffer: ByteBuffer): Unit = { buffer.put(row.getByte(ordinal)) } @@ -272,9 +272,9 @@ private[sql] object BYTE extends NativeColumnType(ByteType, 5, 1) { row.setByte(ordinal, value) } - override def getField(row: Row, ordinal: Int): Byte = row.getByte(ordinal) + override def getField(row: InternalRow, ordinal: Int): Byte = row.getByte(ordinal) - override def copyField(from: Row, fromOrdinal: Int, to: MutableRow, toOrdinal: Int): Unit = { + override def copyField(from: InternalRow, fromOrdinal: Int, to: MutableRow, toOrdinal: Int) { to.setByte(toOrdinal, from.getByte(fromOrdinal)) } } @@ -284,7 +284,7 @@ private[sql] object SHORT extends NativeColumnType(ShortType, 6, 2) { buffer.putShort(v) } - override def append(row: Row, ordinal: Int, buffer: ByteBuffer): Unit = { + override def append(row: InternalRow, ordinal: Int, buffer: ByteBuffer): Unit = { buffer.putShort(row.getShort(ordinal)) } @@ -300,16 +300,16 @@ private[sql] object SHORT extends NativeColumnType(ShortType, 6, 2) { row.setShort(ordinal, value) } - override def getField(row: Row, ordinal: Int): Short = row.getShort(ordinal) + override def getField(row: InternalRow, ordinal: Int): Short = row.getShort(ordinal) - override def copyField(from: Row, fromOrdinal: Int, to: MutableRow, toOrdinal: Int): Unit = { + override def copyField(from: InternalRow, fromOrdinal: Int, to: MutableRow, toOrdinal: Int) { to.setShort(toOrdinal, from.getShort(fromOrdinal)) } } private[sql] object STRING extends NativeColumnType(StringType, 7, 8) { - override def actualSize(row: Row, ordinal: Int): Int = { - row.getString(ordinal).getBytes("utf-8").length + 4 + override def actualSize(row: InternalRow, ordinal: Int): Int = { + row.getAs[UTF8String](ordinal).getBytes.length + 4 } override def append(v: UTF8String, buffer: ByteBuffer): Unit = { @@ -328,11 +328,11 @@ private[sql] object STRING extends NativeColumnType(StringType, 7, 8) { row.update(ordinal, value) } - override def getField(row: Row, ordinal: Int): UTF8String = { + override def getField(row: InternalRow, ordinal: Int): UTF8String = { row(ordinal).asInstanceOf[UTF8String] } - override def copyField(from: Row, fromOrdinal: Int, to: MutableRow, toOrdinal: Int): Unit = { + override def copyField(from: InternalRow, fromOrdinal: Int, to: MutableRow, toOrdinal: Int) { to.update(toOrdinal, from(fromOrdinal)) } } @@ -346,7 +346,7 @@ 
private[sql] object DATE extends NativeColumnType(DateType, 8, 4) { buffer.putInt(v) } - override def getField(row: Row, ordinal: Int): Int = { + override def getField(row: InternalRow, ordinal: Int): Int = { row(ordinal).asInstanceOf[Int] } @@ -364,7 +364,7 @@ private[sql] object TIMESTAMP extends NativeColumnType(TimestampType, 9, 8) { buffer.putLong(v) } - override def getField(row: Row, ordinal: Int): Long = { + override def getField(row: InternalRow, ordinal: Int): Long = { row(ordinal).asInstanceOf[Long] } @@ -387,7 +387,7 @@ private[sql] case class FIXED_DECIMAL(precision: Int, scale: Int) buffer.putLong(v.toUnscaledLong) } - override def getField(row: Row, ordinal: Int): Decimal = { + override def getField(row: InternalRow, ordinal: Int): Decimal = { row(ordinal).asInstanceOf[Decimal] } @@ -405,7 +405,7 @@ private[sql] sealed abstract class ByteArrayColumnType[T <: DataType]( defaultSize: Int) extends ColumnType[T, Array[Byte]](typeId, defaultSize) { - override def actualSize(row: Row, ordinal: Int): Int = { + override def actualSize(row: InternalRow, ordinal: Int): Int = { getField(row, ordinal).length + 4 } @@ -426,7 +426,7 @@ private[sql] object BINARY extends ByteArrayColumnType[BinaryType.type](11, 16) row(ordinal) = value } - override def getField(row: Row, ordinal: Int): Array[Byte] = { + override def getField(row: InternalRow, ordinal: Int): Array[Byte] = { row(ordinal).asInstanceOf[Array[Byte]] } } @@ -439,7 +439,7 @@ private[sql] object GENERIC extends ByteArrayColumnType[DataType](12, 16) { row(ordinal) = SparkSqlSerializer.deserialize[Any](value) } - override def getField(row: Row, ordinal: Int): Array[Byte] = { + override def getField(row: InternalRow, ordinal: Int): Array[Byte] = { SparkSqlSerializer.serialize(row(ordinal)) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala index 761f427b8cd0d..2758fd6ac21cb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala @@ -146,7 +146,8 @@ private[sql] case class InMemoryRelation( rowCount += 1 } - val stats = InternalRow.merge(columnBuilders.map(_.columnStats.collectedStatistics) : _*) + val stats = new GenericRow(columnBuilders.map(_.columnStats.collectedStatistics) + .flatMap(_.toSeq).toArray) batchStats += stats CachedBatch(columnBuilders.map(_.build().array()), stats) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer2.scala index 15b6936acd59b..74a22353b1d27 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer2.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer2.scala @@ -26,7 +26,8 @@ import scala.reflect.ClassTag import org.apache.spark.Logging import org.apache.spark.serializer._ import org.apache.spark.sql.Row -import org.apache.spark.sql.catalyst.expressions.{GenericMutableRow, MutableRow, SpecificMutableRow} +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.{MutableRow, SpecificMutableRow} import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String @@ -329,7 +330,7 @@ private[sql] object SparkSqlSerializer2 { */ def createDeserializationFunction( schema: Array[DataType], - in: 
DataInputStream): (MutableRow) => Row = { + in: DataInputStream): (MutableRow) => InternalRow = { if (schema == null) { (mutableRow: MutableRow) => null } else { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala index 1ce150ceaf5f9..cb78164a591e7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala @@ -157,9 +157,9 @@ object EvaluatePython { } /** - * Convert Row into Java Array (for pickled into Python) + * Convert InternalRow into Java Array (for pickled into Python) */ - def rowToArray(row: Row, fields: Seq[DataType]): Array[Any] = { + def rowToArray(row: InternalRow, fields: Seq[DataType]): Array[Any] = { // TODO: this is slow! row.toSeq.zip(fields).map {case (obj, dt) => toJava(obj, dt)}.toArray } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala index 93383e5a62f11..84e2988b7768d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala @@ -23,6 +23,7 @@ import org.apache.spark.sql.catalyst.expressions.{GenericMutableRow, Cast} import org.apache.spark.sql.catalyst.plans.logical.LocalRelation import org.apache.spark.sql.functions._ import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.UTF8String private[sql] object StatFunctions extends Logging { @@ -123,7 +124,7 @@ private[sql] object StatFunctions extends Logging { countsRow.setLong(distinctCol2.get(row.get(1)).get + 1, row.getLong(2)) } // the value of col1 is the first value, the rest are the counts - countsRow.setString(0, col1Item.toString) + countsRow.update(0, UTF8String.fromString(col1Item.toString)) countsRow }.toSeq val headerNames = distinctCol2.map(r => StructField(r._1.toString, LongType)).toSeq diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRDD.scala index 226b143923df6..90139741a50cd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRDD.scala @@ -417,7 +417,7 @@ private[sql] class JDBCRDD( case IntegerConversion => mutableRow.setInt(i, rs.getInt(pos)) case LongConversion => mutableRow.setLong(i, rs.getLong(pos)) // TODO(davies): use getBytes for better performance, if the encoding is UTF-8 - case StringConversion => mutableRow.setString(i, rs.getString(pos)) + case StringConversion => mutableRow.update(i, UTF8String.fromString(rs.getString(pos))) case TimestampConversion => val t = rs.getTimestamp(pos) if (t != null) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/jdbc.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/jdbc.scala index dd8aaf6474895..2ff15393cca9b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/jdbc.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/jdbc.scala @@ -86,7 +86,7 @@ package object jdbc { case ShortType => stmt.setInt(i + 1, row.getShort(i)) case ByteType => stmt.setInt(i + 1, row.getByte(i)) case BooleanType => stmt.setBoolean(i + 1, row.getBoolean(i)) - case StringType => stmt.setString(i + 1, row.getString(i)) + case StringType => stmt.setString(i + 1, row.getAs[String](i)) case BinaryType => 
stmt.setBytes(i + 1, row.getAs[Array[Byte]](i)) case TimestampType => stmt.setTimestamp(i + 1, row.getAs[java.sql.Timestamp](i)) case DateType => stmt.setDate(i + 1, row.getAs[java.sql.Date](i)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala index bba6f1ec96aa8..7bb175106742c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala @@ -58,7 +58,7 @@ private[sql] class DefaultSource extends HadoopFsRelationProvider { // NOTE: This class is instantiated and used on executor side only, no need to be serializable. private[sql] class ParquetOutputWriter(path: String, context: TaskAttemptContext) - extends OutputWriter { + extends InternalOutputWriter { private val recordWriter: RecordWriter[Void, InternalRow] = { val conf = context.getConfiguration @@ -112,7 +112,7 @@ private[sql] class ParquetOutputWriter(path: String, context: TaskAttemptContext outputFormat.getRecordWriter(context) } - override def write(row: Row): Unit = recordWriter.write(null, row.asInstanceOf[InternalRow]) + override def write(row: InternalRow): Unit = recordWriter.write(null, row) override def close(): Unit = recordWriter.close(context) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/commands.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/commands.scala index 3dbe6faabf453..bebba064aba95 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/commands.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/commands.scala @@ -44,7 +44,7 @@ private[sql] case class InsertIntoDataSource( overwrite: Boolean) extends RunnableCommand { - override def run(sqlContext: SQLContext): Seq[InternalRow] = { + override def run(sqlContext: SQLContext): Seq[Row] = { val relation = logicalRelation.relation.asInstanceOf[InsertableRelation] val data = DataFrame(sqlContext, query) // Apply the schema of the existing table to the new data. @@ -54,7 +54,7 @@ private[sql] case class InsertIntoDataSource( // Invalidate the cache. 
sqlContext.cacheManager.invalidateCache(logicalRelation) - Seq.empty[InternalRow] + Seq.empty[Row] } } @@ -64,7 +64,7 @@ private[sql] case class InsertIntoHadoopFsRelation( mode: SaveMode) extends RunnableCommand { - override def run(sqlContext: SQLContext): Seq[InternalRow] = { + override def run(sqlContext: SQLContext): Seq[Row] = { require( relation.paths.length == 1, s"Cannot write to multiple destinations: ${relation.paths.mkString(",")}") @@ -116,7 +116,7 @@ private[sql] case class InsertIntoHadoopFsRelation( } } - Seq.empty[InternalRow] + Seq.empty[Row] } private def insert(writerContainer: BaseWriterContainer, df: DataFrame): Unit = { @@ -143,14 +143,15 @@ private[sql] case class InsertIntoHadoopFsRelation( try { writerContainer.executorSideSetup(taskContext) - val converter = if (needsConversion) { - CatalystTypeConverters.createToScalaConverter(dataSchema).asInstanceOf[InternalRow => Row] - } else { - r: InternalRow => r.asInstanceOf[Row] - } + val converter = CatalystTypeConverters.createToScalaConverter(dataSchema) + .asInstanceOf[InternalRow => Row] while (iterator.hasNext) { - val row = converter(iterator.next()) - writerContainer.outputWriterForRow(row).write(row) + val row = iterator.next() + val converted = converter(row) + writerContainer.outputWriterForRow(converted) match { + case w: InternalOutputWriter => w.write(row) + case w: OutputWriter => w.write(converted) + } } writerContainer.commitTask() @@ -212,21 +213,21 @@ private[sql] case class InsertIntoHadoopFsRelation( val partitionProj = newProjection(codegenEnabled, partitionOutput, output) val dataProj = newProjection(codegenEnabled, dataOutput, output) - val dataConverter: InternalRow => Row = if (needsConversion) { + val dataConverter = CatalystTypeConverters.createToScalaConverter(dataSchema).asInstanceOf[InternalRow => Row] - } else { - r: InternalRow => r.asInstanceOf[Row] - } val partitionSchema = StructType.fromAttributes(partitionOutput) val partConverter: InternalRow => Row = CatalystTypeConverters.createToScalaConverter(partitionSchema) .asInstanceOf[InternalRow => Row] - while (iterator.hasNext) { val row = iterator.next() - val partitionPart = partConverter(partitionProj(row)) - val dataPart = dataConverter(dataProj(row)) - writerContainer.outputWriterForRow(partitionPart).write(dataPart) + val part = partConverter(partitionProj(row)) + val dataPart = dataProj(row) + val writer = writerContainer.outputWriterForRow(part) + writer match { + case w: InternalOutputWriter => w.write(dataPart) + case w: OutputWriter => w.write(dataConverter(dataPart)) + } } writerContainer.commitTask() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala index b7095c8ead797..e0915a85c9fa6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala @@ -26,11 +26,11 @@ import org.apache.spark.Logging import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.sql.catalyst.AbstractSparkSQLParser import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation -import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, InternalRow} +import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference} import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.execution.RunnableCommand import org.apache.spark.sql.types._ -import org.apache.spark.sql.{AnalysisException, DataFrame, SQLContext, 
SaveMode} +import org.apache.spark.sql.{AnalysisException, DataFrame, Row, SQLContext, SaveMode} import org.apache.spark.util.Utils /** @@ -408,7 +408,7 @@ private[sql] case class CreateTempTableUsing( provider: String, options: Map[String, String]) extends RunnableCommand { - def run(sqlContext: SQLContext): Seq[InternalRow] = { + def run(sqlContext: SQLContext): Seq[Row] = { val resolved = ResolvedDataSource( sqlContext, userSpecifiedSchema, Array.empty[String], provider, options) sqlContext.registerDataFrameAsTable( @@ -425,7 +425,7 @@ private[sql] case class CreateTempTableUsingAsSelect( options: Map[String, String], query: LogicalPlan) extends RunnableCommand { - override def run(sqlContext: SQLContext): Seq[InternalRow] = { + override def run(sqlContext: SQLContext): Seq[Row] = { val df = DataFrame(sqlContext, query) val resolved = ResolvedDataSource(sqlContext, provider, partitionColumns, mode, options, df) sqlContext.registerDataFrameAsTable( @@ -438,7 +438,7 @@ private[sql] case class CreateTempTableUsingAsSelect( private[sql] case class RefreshTable(databaseName: String, tableName: String) extends RunnableCommand { - override def run(sqlContext: SQLContext): Seq[InternalRow] = { + override def run(sqlContext: SQLContext): Seq[Row] = { // Refresh the given table's metadata first. sqlContext.catalog.refreshTable(databaseName, tableName) @@ -457,7 +457,7 @@ private[sql] case class RefreshTable(databaseName: String, tableName: String) sqlContext.cacheManager.cacheQuery(df, Some(tableName)) } - Seq.empty[InternalRow] + Seq.empty[Row] } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala index 43d3507d7d2ba..634ce5a1aa2b9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala @@ -341,6 +341,16 @@ abstract class OutputWriter { def close(): Unit } +/** + * internal only + */ +abstract class InternalOutputWriter extends OutputWriter { + + def write(row: Row): Unit = throw new UnsupportedOperationException + + def write(row: InternalRow): Unit +} + /** * ::Experimental:: * A [[BaseRelation]] that provides much of the common code required for formats that store their diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/DDLTestSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/DDLTestSuite.scala index 5fc53f7012994..1c691e2153690 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/DDLTestSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/DDLTestSuite.scala @@ -19,7 +19,6 @@ package org.apache.spark.sql.sources import org.apache.spark.rdd.RDD import org.apache.spark.sql._ -import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String @@ -62,7 +61,7 @@ case class SimpleDDLScan(from: Int, to: Int, table: String)(@transient val sqlCo override def buildScan(): RDD[Row] = { sqlContext.sparkContext.parallelize(from to to).map { e => - InternalRow(UTF8String.fromString(s"people$e"), e * 2) + Row(UTF8String.fromString(s"people$e"), e * 2) } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala index 48875773224c7..ef415371e6263 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala +++ 
b/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala @@ -72,7 +72,7 @@ case class AllDataTypesScan( override def buildScan(): RDD[Row] = { sqlContext.sparkContext.parallelize(from to to).map { i => - InternalRow( + Row( UTF8String.fromString(s"str_$i"), s"str_$i".getBytes(), i % 2 == 0, @@ -88,9 +88,9 @@ case class AllDataTypesScan( DateUtils.fromJavaTimestamp(new Timestamp(20000 + i)), UTF8String.fromString(s"varchar_$i"), Seq(i, i + 1), - Seq(Map(UTF8String.fromString(s"str_$i") -> InternalRow(i.toLong))), + Seq(Map(UTF8String.fromString(s"str_$i") -> Row(i.toLong))), Map(i -> i.toString), - Map(Map(UTF8String.fromString(s"str_$i") -> i.toFloat) -> InternalRow(i.toLong)), + Map(Map(UTF8String.fromString(s"str_$i") -> i.toFloat) -> Row(i.toLong)), Row(i, i.toString), Row(Seq(UTF8String.fromString(s"str_$i"), UTF8String.fromString(s"str_${i + 1}")), InternalRow(Seq(DateUtils.fromJavaDate(new Date(1970, 1, i + 1)))))) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala index 485810320f3c1..c790ba29bb9ba 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala @@ -30,6 +30,7 @@ import org.apache.hadoop.hive.serde2.objectinspector.{ObjectInspectorConverters, import org.apache.hadoop.io.Writable import org.apache.hadoop.mapred.{FileInputFormat, InputFormat, JobConf} +import org.apache.spark.unsafe.types.UTF8String import org.apache.spark.{Logging, SerializableWritable} import org.apache.spark.broadcast.Broadcast import org.apache.spark.rdd.{EmptyRDD, HadoopRDD, RDD, UnionRDD} @@ -357,7 +358,7 @@ private[hive] object HadoopTableReader extends HiveInspectors with Logging { (value: Any, row: MutableRow, ordinal: Int) => row.setDouble(ordinal, oi.get(value)) case oi: HiveVarcharObjectInspector => (value: Any, row: MutableRow, ordinal: Int) => - row.setString(ordinal, oi.getPrimitiveJavaObject(value).getValue) + row.update(ordinal, UTF8String.fromString(oi.getPrimitiveJavaObject(value).getValue)) case oi: HiveDecimalObjectInspector => (value: Any, row: MutableRow, ordinal: Int) => row.update(ordinal, HiveShim.toCatalystDecimal(oi, value)) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateTableAsSelect.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateTableAsSelect.scala index 0e4a2427a9c15..84358cb73c9e3 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateTableAsSelect.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateTableAsSelect.scala @@ -17,13 +17,11 @@ package org.apache.spark.sql.hive.execution -import org.apache.spark.annotation.Experimental -import org.apache.spark.sql.{AnalysisException, SQLContext} -import org.apache.spark.sql.catalyst.expressions.InternalRow import org.apache.spark.sql.catalyst.plans.logical.{InsertIntoTable, LogicalPlan} import org.apache.spark.sql.execution.RunnableCommand -import org.apache.spark.sql.hive.client.{HiveTable, HiveColumn} -import org.apache.spark.sql.hive.{HiveContext, MetastoreRelation, HiveMetastoreTypes} +import org.apache.spark.sql.hive.client.{HiveColumn, HiveTable} +import org.apache.spark.sql.hive.{HiveContext, HiveMetastoreTypes, MetastoreRelation} +import org.apache.spark.sql.{AnalysisException, Row, SQLContext} /** * Create table and insert the query result into it. 
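
A note on the recurring pattern in the SQLContext, StatFunctions, JDBCRDD, and HadoopTableReader changes above: MutableRow.setString is now marked test-only, so production code writes strings into internal rows with update(ordinal, UTF8String.fromString(...)), matching the StringType -> UTF8String mapping documented on InternalRow. A minimal sketch of the pattern, assuming the arity-taking GenericMutableRow constructor that allocates an empty row:

    import org.apache.spark.sql.catalyst.expressions.GenericMutableRow
    import org.apache.spark.unsafe.types.UTF8String

    // Allocate a one-column mutable row and store the string in its
    // internal representation: convert once, at the boundary.
    val row = new GenericMutableRow(1)
    row.update(0, UTF8String.fromString("hello"))

    // The stored value stays a UTF8String; getString converts back to
    // a JVM String only when explicitly requested.
    val internal = row.getAs[UTF8String](0)
    val external = row.getString(0)
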
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateTableAsSelect.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateTableAsSelect.scala
index 0e4a2427a9c15..84358cb73c9e3 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateTableAsSelect.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateTableAsSelect.scala
@@ -17,13 +17,11 @@
 
 package org.apache.spark.sql.hive.execution
 
-import org.apache.spark.annotation.Experimental
-import org.apache.spark.sql.{AnalysisException, SQLContext}
-import org.apache.spark.sql.catalyst.expressions.InternalRow
 import org.apache.spark.sql.catalyst.plans.logical.{InsertIntoTable, LogicalPlan}
 import org.apache.spark.sql.execution.RunnableCommand
-import org.apache.spark.sql.hive.client.{HiveTable, HiveColumn}
-import org.apache.spark.sql.hive.{HiveContext, MetastoreRelation, HiveMetastoreTypes}
+import org.apache.spark.sql.hive.client.{HiveColumn, HiveTable}
+import org.apache.spark.sql.hive.{HiveContext, HiveMetastoreTypes, MetastoreRelation}
+import org.apache.spark.sql.{AnalysisException, Row, SQLContext}
 
 /**
  * Create table and insert the query result into it.
@@ -42,11 +40,11 @@ case class CreateTableAsSelect(
   def database: String = tableDesc.database
   def tableName: String = tableDesc.name
 
-  override def run(sqlContext: SQLContext): Seq[InternalRow] = {
+  override def run(sqlContext: SQLContext): Seq[Row] = {
     val hiveContext = sqlContext.asInstanceOf[HiveContext]
     lazy val metastoreRelation: MetastoreRelation = {
-      import org.apache.hadoop.hive.serde2.`lazy`.LazySimpleSerDe
       import org.apache.hadoop.hive.ql.io.HiveIgnoreKeyTextOutputFormat
+      import org.apache.hadoop.hive.serde2.`lazy`.LazySimpleSerDe
       import org.apache.hadoop.io.Text
       import org.apache.hadoop.mapred.TextInputFormat
 
@@ -89,7 +87,7 @@ case class CreateTableAsSelect(
       hiveContext.executePlan(InsertIntoTable(metastoreRelation, Map(), query, true, false)).toRdd
     }
 
-    Seq.empty[InternalRow]
+    Seq.empty[Row]
   }
 
   override def argString: String = {
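
Note: CreateTableAsSelect above shows the command-side convention this patch settles on: RunnableCommand.run returns the public Seq[Row], and InternalRow stays confined to the execution data path. A minimal sketch of a command under that contract (NoopCommand and its body are hypothetical):

    import org.apache.spark.sql.{Row, SQLContext}
    import org.apache.spark.sql.catalyst.expressions.Attribute
    import org.apache.spark.sql.execution.RunnableCommand

    // Hypothetical side-effect-only command, for illustration.
    case class NoopCommand(message: String) extends RunnableCommand {
      override def output: Seq[Attribute] = Seq.empty

      override def run(sqlContext: SQLContext): Seq[Row] = {
        sqlContext.sparkContext.setJobDescription(message)  // the side effect
        Seq.empty[Row]  // commands report results as public Rows
      }
    }
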
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/DescribeHiveTableCommand.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/DescribeHiveTableCommand.scala
index a89381000ad5f..5f0ed5393d191 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/DescribeHiveTableCommand.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/DescribeHiveTableCommand.scala
@@ -21,10 +21,10 @@ import scala.collection.JavaConversions._
 
 import org.apache.hadoop.hive.metastore.api.FieldSchema
 
-import org.apache.spark.sql.SQLContext
-import org.apache.spark.sql.catalyst.expressions.{Attribute, InternalRow}
+import org.apache.spark.sql.catalyst.expressions.Attribute
 import org.apache.spark.sql.execution.RunnableCommand
 import org.apache.spark.sql.hive.MetastoreRelation
+import org.apache.spark.sql.{Row, SQLContext}
 
 /**
  * Implementation for "describe [extended] table".
@@ -35,7 +35,7 @@ case class DescribeHiveTableCommand(
     override val output: Seq[Attribute],
     isExtended: Boolean) extends RunnableCommand {
 
-  override def run(sqlContext: SQLContext): Seq[InternalRow] = {
+  override def run(sqlContext: SQLContext): Seq[Row] = {
     // Trying to mimic the format of Hive's output. But not exactly the same.
     var results: Seq[(String, String, String)] = Nil
 
@@ -57,7 +57,7 @@ case class DescribeHiveTableCommand(
     }
 
     results.map { case (name, dataType, comment) =>
-      InternalRow(name, dataType, comment)
+      Row(name, dataType, comment)
     }
   }
 }
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveNativeCommand.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveNativeCommand.scala
index 87f8e3f7fcfcc..41b645b2c9c93 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveNativeCommand.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveNativeCommand.scala
@@ -17,11 +17,11 @@
 
 package org.apache.spark.sql.hive.execution
 
-import org.apache.spark.sql.catalyst.expressions.{AttributeReference, InternalRow}
+import org.apache.spark.sql.catalyst.expressions.AttributeReference
 import org.apache.spark.sql.execution.RunnableCommand
 import org.apache.spark.sql.hive.HiveContext
-import org.apache.spark.sql.SQLContext
 import org.apache.spark.sql.types.StringType
+import org.apache.spark.sql.{Row, SQLContext}
 
 private[hive]
 case class HiveNativeCommand(sql: String) extends RunnableCommand {
@@ -29,6 +29,6 @@ case class HiveNativeCommand(sql: String) extends RunnableCommand {
   override def output: Seq[AttributeReference] =
     Seq(AttributeReference("result", StringType, nullable = false)())
 
-  override def run(sqlContext: SQLContext): Seq[InternalRow] =
-    sqlContext.asInstanceOf[HiveContext].runSqlHive(sql).map(InternalRow(_))
+  override def run(sqlContext: SQLContext): Seq[Row] =
+    sqlContext.asInstanceOf[HiveContext].runSqlHive(sql).map(Row(_))
 }
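
Note: the InsertIntoHiveTable diff that follows keeps its results in sideEffectResult, a lazy val, so the actual write happens at most once even if the physical plan is executed more than once. A generic sketch of that run-once idiom (names hypothetical):

    // First access runs the side effect; later accesses reuse the cached value.
    class RunOnce(doInsert: () => Seq[String]) {
      protected lazy val sideEffectResult: Seq[String] = doInsert()

      def executeCollect(): Seq[String] = sideEffectResult  // never inserts twice
    }
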
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala
index 1d306c5d10af8..3f20b187b0ded 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala
@@ -120,7 +120,7 @@ case class InsertIntoHiveTable(
    *
    * Note: this is run once and then kept to avoid double insertions.
    */
-  protected[sql] lazy val sideEffectResult: Seq[InternalRow] = {
+  protected[sql] lazy val sideEffectResult: Seq[Row] = {
     // Have to pass the TableDesc object to RDD.mapPartitions and then instantiate new serializer
     // instances within the closure, since Serializer is not serializable while TableDesc is.
     val tableDesc = table.tableDesc
@@ -251,7 +251,7 @@ case class InsertIntoHiveTable(
     // however for now we return an empty list to simplify compatibility checks with hive, which
     // does not return anything for insert operations.
     // TODO: implement hive compatibility as rules.
-    Seq.empty[InternalRow]
+    Seq.empty[Row]
   }
 
   override def executeCollect(): Array[Row] =
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/commands.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/commands.scala
index 195e5752c3ec0..0a909b695755e 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/commands.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/commands.scala
@@ -17,15 +17,14 @@
 
 package org.apache.spark.sql.hive.execution
 
-import org.apache.spark.sql.AnalysisException
+import org.apache.spark.sql._
 import org.apache.spark.sql.catalyst.analysis.EliminateSubQueries
-import org.apache.spark.sql.catalyst.util._
-import org.apache.spark.sql.sources._
-import org.apache.spark.sql.{SaveMode, DataFrame, SQLContext}
-import org.apache.spark.sql.catalyst.expressions.{Attribute, InternalRow}
+import org.apache.spark.sql.catalyst.expressions.Attribute
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.catalyst.util._
 import org.apache.spark.sql.execution.RunnableCommand
 import org.apache.spark.sql.hive.HiveContext
+import org.apache.spark.sql.sources._
 import org.apache.spark.sql.types._
 import org.apache.spark.util.Utils
 
@@ -39,9 +38,9 @@ import org.apache.spark.util.Utils
 private[hive]
 case class AnalyzeTable(tableName: String) extends RunnableCommand {
 
-  override def run(sqlContext: SQLContext): Seq[InternalRow] = {
+  override def run(sqlContext: SQLContext): Seq[Row] = {
     sqlContext.asInstanceOf[HiveContext].analyze(tableName)
-    Seq.empty[InternalRow]
+    Seq.empty[Row]
   }
 }
 
@@ -53,7 +52,7 @@ case class DropTable(
     tableName: String,
     ifExists: Boolean) extends RunnableCommand {
 
-  override def run(sqlContext: SQLContext): Seq[InternalRow] = {
+  override def run(sqlContext: SQLContext): Seq[Row] = {
     val hiveContext = sqlContext.asInstanceOf[HiveContext]
     val ifExistsClause = if (ifExists) "IF EXISTS " else ""
     try {
@@ -70,7 +69,7 @@ case class DropTable(
     hiveContext.invalidateTable(tableName)
     hiveContext.runSqlHive(s"DROP TABLE $ifExistsClause$tableName")
     hiveContext.catalog.unregisterTable(Seq(tableName))
-    Seq.empty[InternalRow]
+    Seq.empty[Row]
   }
 }
 
@@ -83,7 +82,7 @@ case class AddJar(path: String) extends RunnableCommand {
     schema.toAttributes
   }
 
-  override def run(sqlContext: SQLContext): Seq[InternalRow] = {
+  override def run(sqlContext: SQLContext): Seq[Row] = {
     val hiveContext = sqlContext.asInstanceOf[HiveContext]
     val currentClassLoader = Utils.getContextOrSparkClassLoader
 
@@ -99,18 +98,18 @@ case class AddJar(path: String) extends RunnableCommand {
     // Add jar to executors
     hiveContext.sparkContext.addJar(path)
 
-    Seq(InternalRow(0))
+    Seq(Row(0))
   }
 }
 
 private[hive]
 case class AddFile(path: String) extends RunnableCommand {
 
-  override def run(sqlContext: SQLContext): Seq[InternalRow] = {
+  override def run(sqlContext: SQLContext): Seq[Row] = {
     val hiveContext = sqlContext.asInstanceOf[HiveContext]
     hiveContext.runSqlHive(s"ADD FILE $path")
     hiveContext.sparkContext.addFile(path)
-    Seq.empty[InternalRow]
+    Seq.empty[Row]
   }
 }
 
@@ -123,12 +122,12 @@ case class CreateMetastoreDataSource(
     allowExisting: Boolean,
     managedIfNoPath: Boolean) extends RunnableCommand {
 
-  override def run(sqlContext: SQLContext): Seq[InternalRow] = {
+  override def run(sqlContext: SQLContext): Seq[Row] = {
     val hiveContext = sqlContext.asInstanceOf[HiveContext]
 
     if (hiveContext.catalog.tableExists(tableName :: Nil)) {
       if (allowExisting) {
-        return Seq.empty[InternalRow]
+        return Seq.empty[Row]
       } else {
         throw new AnalysisException(s"Table $tableName already exists.")
       }
@@ -151,7 +150,7 @@ case class CreateMetastoreDataSource(
       optionsWithPath,
       isExternal)
 
-    Seq.empty[InternalRow]
+    Seq.empty[Row]
   }
 }
 
@@ -164,7 +163,7 @@ case class CreateMetastoreDataSourceAsSelect(
     options: Map[String, String],
     query: LogicalPlan) extends RunnableCommand {
 
-  override def run(sqlContext: SQLContext): Seq[InternalRow] = {
+  override def run(sqlContext: SQLContext): Seq[Row] = {
     val hiveContext = sqlContext.asInstanceOf[HiveContext]
     var createMetastoreTable = false
     var isExternal = true
@@ -188,7 +187,7 @@ case class CreateMetastoreDataSourceAsSelect(
           s"Or, if you are using SQL CREATE TABLE, you need to drop $tableName first.")
       case SaveMode.Ignore =>
         // Since the table already exists and the save mode is Ignore, we will just return.
-        return Seq.empty[InternalRow]
+        return Seq.empty[Row]
       case SaveMode.Append =>
         // Check if the specified data source match the data source of the existing table.
         val resolved = ResolvedDataSource(
@@ -253,6 +252,6 @@ case class CreateMetastoreDataSourceAsSelect(
 
     // Refresh the cache of the table in the catalog.
     hiveContext.refreshTable(tableName)
-    Seq.empty[InternalRow]
+    Seq.empty[Row]
   }
 }
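
Note: the hiveWriterContainers diff that follows moves getLocalFileWriter to InternalRow; in the dynamic-partition container, each partition value maps to its own RecordWriter, cached by output path. A simplified sketch of that per-partition cache (WriterCache and newWriter are hypothetical names, not this file's API):

    import scala.collection.mutable

    // Simplified per-partition writer cache, in the spirit of
    // SparkHiveDynamicPartitionWriterContainer.getLocalFileWriter.
    class WriterCache[W](newWriter: String => W) {
      private val writers = mutable.HashMap.empty[String, W]

      // Reuse the writer for a partition path, creating it on first use.
      def writerFor(partitionPath: String): W =
        writers.getOrElseUpdate(partitionPath, newWriter(partitionPath))
    }
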
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala
index ee440e304ec19..aadf4dc230d93 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala
@@ -23,20 +23,20 @@ import java.util.Date
 import scala.collection.mutable
 
 import org.apache.hadoop.fs.Path
+import org.apache.hadoop.hive.common.FileUtils
 import org.apache.hadoop.hive.conf.HiveConf.ConfVars
 import org.apache.hadoop.hive.ql.exec.{FileSinkOperator, Utilities}
 import org.apache.hadoop.hive.ql.io.{HiveFileFormatUtils, HiveOutputFormat}
 import org.apache.hadoop.hive.ql.plan.{PlanUtils, TableDesc}
 import org.apache.hadoop.io.Writable
 import org.apache.hadoop.mapred._
-import org.apache.hadoop.hive.common.FileUtils
 
 import org.apache.spark.mapred.SparkHadoopMapRedUtil
-import org.apache.spark.sql.Row
-import org.apache.spark.{Logging, SerializableWritable, SparkHadoopWriter}
+import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.util.DateUtils
 import org.apache.spark.sql.hive.HiveShim.{ShimFileSinkDesc => FileSinkDesc}
 import org.apache.spark.sql.types._
+import org.apache.spark.{Logging, SerializableWritable, SparkHadoopWriter}
 
 /**
  * Internal helper class that saves an RDD using a Hive OutputFormat.
@@ -93,7 +93,8 @@ private[hive] class SparkHiveWriterContainer(
     "part-" + numberFormat.format(splitID) + extension
   }
 
-  def getLocalFileWriter(row: Row, schema: StructType): FileSinkOperator.RecordWriter = writer
+  def getLocalFileWriter(row: InternalRow, schema: StructType): FileSinkOperator.RecordWriter =
+    writer
 
   def close() {
     // Seems the boolean value passed into close does not matter.
@@ -164,7 +165,7 @@ private[spark] class SparkHiveDynamicPartitionWriterContainer(
     dynamicPartColNames: Array[String])
   extends SparkHiveWriterContainer(jobConf, fileSinkConf) {
 
-  import SparkHiveDynamicPartitionWriterContainer._
+  import org.apache.spark.sql.hive.SparkHiveDynamicPartitionWriterContainer._
 
   private val defaultPartName = jobConf.get(
     ConfVars.DEFAULTPARTITIONNAME.varname, ConfVars.DEFAULTPARTITIONNAME.defaultVal)
@@ -196,7 +197,9 @@ private[spark] class SparkHiveDynamicPartitionWriterContainer(
     jobConf.setBoolean(SUCCESSFUL_JOB_OUTPUT_DIR_MARKER, oldMarker)
   }
 
-  override def getLocalFileWriter(row: Row, schema: StructType): FileSinkOperator.RecordWriter = {
+  override def getLocalFileWriter(
+      row: InternalRow,
+      schema: StructType): FileSinkOperator.RecordWriter = {
     def convertToHiveRawString(col: String, value: Any): String = {
       val raw = String.valueOf(value)
       schema(col).dataType match {
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala
index f03c4cd54e7e6..efd4c2b0b89fe 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala
@@ -63,7 +63,7 @@ private[orc] class OrcOutputWriter(
     path: String,
     dataSchema: StructType,
     context: TaskAttemptContext)
-  extends OutputWriter with SparkHadoopMapRedUtil with HiveInspectors {
+  extends InternalOutputWriter with SparkHadoopMapRedUtil with HiveInspectors {
 
   private val serializer = {
     val table = new Properties()
@@ -115,7 +115,7 @@ private[orc] class OrcOutputWriter(
     ).asInstanceOf[RecordWriter[NullWritable, Writable]]
   }
 
-  override def write(row: Row): Unit = {
+  override def write(row: InternalRow): Unit = {
     var i = 0
     while (i < row.length) {
       reusableOutputBuffer(i) = wrappers(i)(row(i))
@@ -188,7 +188,7 @@ private[sql] class OrcRelation(
       filters: Array[Filter],
       inputPaths: Array[FileStatus]): RDD[Row] = {
     val output = StructType(requiredColumns.map(dataSchema(_))).toAttributes
-    OrcTableScan(output, this, filters, inputPaths).execute()
+    OrcTableScan(output, this, filters, inputPaths).execute().map(_.asInstanceOf[Row])
   }
 
   override def prepareJobForWrite(job: Job): OutputWriterFactory = {
@@ -222,13 +222,13 @@ private[orc] case class OrcTableScan(
     HiveShim.appendReadColumns(conf, sortedIds, sortedNames)
   }
 
-  // Transform all given raw `Writable`s into `Row`s.
+  // Transform all given raw `Writable`s into `InternalRow`s.
   private def fillObject(
       path: String,
       conf: Configuration,
       iterator: Iterator[Writable],
       nonPartitionKeyAttrs: Seq[(Attribute, Int)],
-      mutableRow: MutableRow): Iterator[Row] = {
+      mutableRow: MutableRow): Iterator[InternalRow] = {
     val deserializer = new OrcSerde
     val soi = OrcFileOperator.getObjectInspector(path, Some(conf))
     val (fieldRefs, fieldOrdinals) = nonPartitionKeyAttrs.map {
@@ -249,11 +249,11 @@ private[orc] case class OrcTableScan(
         }
         i += 1
       }
-      mutableRow: Row
+      mutableRow: InternalRow
     }
   }
 
-  def execute(): RDD[Row] = {
+  def execute(): RDD[InternalRow] = {
     val job = new Job(sqlContext.sparkContext.hadoopConfiguration)
     val conf = job.getConfiguration
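
Note: OrcRelation.buildScan above bridges the now-internal scan back to the public API with an element-wise cast. A sketch of that bridge in isolation; the cast is only safe while the concrete row classes implement both traits, as they do at this point of the migration:

    import org.apache.spark.rdd.RDD
    import org.apache.spark.sql.Row
    import org.apache.spark.sql.catalyst.InternalRow

    // Element-wise downcast from the internal row type to the public one,
    // mirroring the .map(_.asInstanceOf[Row]) call in buildScan above.
    def toPublicRows(internal: RDD[InternalRow]): RDD[Row] =
      internal.map(_.asInstanceOf[Row])
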