diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/BaseMutableRow.java b/sql/catalyst/src/main/java/org/apache/spark/sql/BaseMutableRow.java deleted file mode 100644 index acec2bf4520f..000000000000 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/BaseMutableRow.java +++ /dev/null @@ -1,68 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql; - -import org.apache.spark.sql.catalyst.expressions.MutableRow; - -public abstract class BaseMutableRow extends BaseRow implements MutableRow { - - @Override - public void update(int ordinal, Object value) { - throw new UnsupportedOperationException(); - } - - @Override - public void setInt(int ordinal, int value) { - throw new UnsupportedOperationException(); - } - - @Override - public void setLong(int ordinal, long value) { - throw new UnsupportedOperationException(); - } - - @Override - public void setDouble(int ordinal, double value) { - throw new UnsupportedOperationException(); - } - - @Override - public void setBoolean(int ordinal, boolean value) { - throw new UnsupportedOperationException(); - } - - @Override - public void setShort(int ordinal, short value) { - throw new UnsupportedOperationException(); - } - - @Override - public void setByte(int ordinal, byte value) { - throw new UnsupportedOperationException(); - } - - @Override - public void setFloat(int ordinal, float value) { - throw new UnsupportedOperationException(); - } - - @Override - public void setString(int ordinal, String value) { - throw new UnsupportedOperationException(); - } -} diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/BaseRow.java b/sql/catalyst/src/main/java/org/apache/spark/sql/BaseRow.java deleted file mode 100644 index 6a2356f1f9c6..000000000000 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/BaseRow.java +++ /dev/null @@ -1,197 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql; - -import java.math.BigDecimal; -import java.sql.Date; -import java.sql.Timestamp; -import java.util.List; - -import scala.collection.Seq; -import scala.collection.mutable.ArraySeq; - -import org.apache.spark.sql.catalyst.InternalRow; -import org.apache.spark.sql.catalyst.expressions.GenericRow; -import org.apache.spark.sql.types.StructType; - -public abstract class BaseRow extends InternalRow { - - @Override - final public int length() { - return size(); - } - - @Override - public boolean anyNull() { - final int n = size(); - for (int i=0; i < n; i++) { - if (isNullAt(i)) { - return true; - } - } - return false; - } - - @Override - public StructType schema() { throw new UnsupportedOperationException(); } - - @Override - final public Object apply(int i) { - return get(i); - } - - @Override - public int getInt(int i) { - throw new UnsupportedOperationException(); - } - - @Override - public long getLong(int i) { - throw new UnsupportedOperationException(); - } - - @Override - public float getFloat(int i) { - throw new UnsupportedOperationException(); - } - - @Override - public double getDouble(int i) { - throw new UnsupportedOperationException(); - } - - @Override - public byte getByte(int i) { - throw new UnsupportedOperationException(); - } - - @Override - public short getShort(int i) { - throw new UnsupportedOperationException(); - } - - @Override - public boolean getBoolean(int i) { - throw new UnsupportedOperationException(); - } - - @Override - public String getString(int i) { - throw new UnsupportedOperationException(); - } - - @Override - public BigDecimal getDecimal(int i) { - throw new UnsupportedOperationException(); - } - - @Override - public Date getDate(int i) { - throw new UnsupportedOperationException(); - } - - @Override - public Timestamp getTimestamp(int i) { - throw new UnsupportedOperationException(); - } - - @Override - public Seq getSeq(int i) { - throw new UnsupportedOperationException(); - } - - @Override - public List getList(int i) { - throw new UnsupportedOperationException(); - } - - @Override - public scala.collection.Map getMap(int i) { - throw new UnsupportedOperationException(); - } - - @Override - public scala.collection.immutable.Map getValuesMap(Seq fieldNames) { - throw new UnsupportedOperationException(); - } - - @Override - public java.util.Map getJavaMap(int i) { - throw new UnsupportedOperationException(); - } - - @Override - public Row getStruct(int i) { - throw new UnsupportedOperationException(); - } - - @Override - public T getAs(int i) { - throw new UnsupportedOperationException(); - } - - @Override - public T getAs(String fieldName) { - throw new UnsupportedOperationException(); - } - - @Override - public int fieldIndex(String name) { - throw new UnsupportedOperationException(); - } - - @Override - public InternalRow copy() { - final int n = size(); - Object[] arr = new Object[n]; - for (int i = 0; i < n; i++) { - arr[i] = get(i); - } - return new GenericRow(arr); - } - - @Override - public Seq toSeq() { - final int n = size(); - final ArraySeq values = new ArraySeq(n); - for (int i = 0; i < n; i++) { - values.update(i, get(i)); - } - return values; - } - - @Override - public String toString() { - return mkString("[", ",", "]"); - } - - @Override - public String mkString() { - return toSeq().mkString(); - } - - @Override - public String mkString(String sep) { - return toSeq().mkString(sep); - } - - @Override - public String mkString(String start, String sep, String end) { - return toSeq().mkString(start, sep, end); - } -} diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java index bb2f2079b40f..705512146b2d 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java @@ -27,7 +27,6 @@ import scala.collection.mutable.ArraySeq; import org.apache.spark.sql.catalyst.InternalRow; -import org.apache.spark.sql.BaseMutableRow; import org.apache.spark.sql.types.DataType; import org.apache.spark.sql.types.StructType; import org.apache.spark.unsafe.types.UTF8String; @@ -52,7 +51,7 @@ * * Instances of `UnsafeRow` act as pointers to row data stored in this format. */ -public final class UnsafeRow extends BaseMutableRow { +public final class UnsafeRow extends MutableRow { private Object baseObject; private long baseOffset; @@ -176,56 +175,56 @@ public void update(int ordinal, Object value) { } @Override - public void setInt(int ordinal, int value) { + public void setBoolean(int ordinal, boolean value) { assertIndexIsValid(ordinal); setNotNullAt(ordinal); - PlatformDependent.UNSAFE.putInt(baseObject, getFieldOffset(ordinal), value); + PlatformDependent.UNSAFE.putBoolean(baseObject, getFieldOffset(ordinal), value); } @Override - public void setLong(int ordinal, long value) { + public void setByte(int ordinal, byte value) { assertIndexIsValid(ordinal); setNotNullAt(ordinal); - PlatformDependent.UNSAFE.putLong(baseObject, getFieldOffset(ordinal), value); + PlatformDependent.UNSAFE.putByte(baseObject, getFieldOffset(ordinal), value); } @Override - public void setDouble(int ordinal, double value) { + public void setShort(int ordinal, short value) { assertIndexIsValid(ordinal); setNotNullAt(ordinal); - PlatformDependent.UNSAFE.putDouble(baseObject, getFieldOffset(ordinal), value); + PlatformDependent.UNSAFE.putShort(baseObject, getFieldOffset(ordinal), value); } @Override - public void setBoolean(int ordinal, boolean value) { + public void setInt(int ordinal, int value) { assertIndexIsValid(ordinal); setNotNullAt(ordinal); - PlatformDependent.UNSAFE.putBoolean(baseObject, getFieldOffset(ordinal), value); + PlatformDependent.UNSAFE.putInt(baseObject, getFieldOffset(ordinal), value); } @Override - public void setShort(int ordinal, short value) { + public void setLong(int ordinal, long value) { assertIndexIsValid(ordinal); setNotNullAt(ordinal); - PlatformDependent.UNSAFE.putShort(baseObject, getFieldOffset(ordinal), value); + PlatformDependent.UNSAFE.putLong(baseObject, getFieldOffset(ordinal), value); } @Override - public void setByte(int ordinal, byte value) { + public void setFloat(int ordinal, float value) { assertIndexIsValid(ordinal); setNotNullAt(ordinal); - PlatformDependent.UNSAFE.putByte(baseObject, getFieldOffset(ordinal), value); + PlatformDependent.UNSAFE.putFloat(baseObject, getFieldOffset(ordinal), value); } @Override - public void setFloat(int ordinal, float value) { + public void setDouble(int ordinal, double value) { assertIndexIsValid(ordinal); setNotNullAt(ordinal); - PlatformDependent.UNSAFE.putFloat(baseObject, getFieldOffset(ordinal), value); + PlatformDependent.UNSAFE.putDouble(baseObject, getFieldOffset(ordinal), value); } @Override - public int size() { + public int length() { return numFields; } @@ -235,7 +234,7 @@ public StructType schema() { } @Override - public Object get(int i) { + public Object apply(int i) { assertIndexIsValid(i); assert (schema != null) : "Schema must be defined when calling generic get() method"; final DataType dataType = schema.fields()[i].dataType(); diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala index e99d5c87a44f..11ed2bc78c07 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala @@ -251,28 +251,28 @@ trait Row extends Serializable { * * @throws ClassCastException when data type does not match. */ - def getDecimal(i: Int): java.math.BigDecimal = apply(i).asInstanceOf[java.math.BigDecimal] + def getDecimal(i: Int): java.math.BigDecimal = getAs[java.math.BigDecimal](i) /** * Returns the value at position i of date type as java.sql.Date. * * @throws ClassCastException when data type does not match. */ - def getDate(i: Int): java.sql.Date = apply(i).asInstanceOf[java.sql.Date] + def getDate(i: Int): java.sql.Date = getAs[java.sql.Date](i) /** * Returns the value at position i of date type as java.sql.Timestamp. * * @throws ClassCastException when data type does not match. */ - def getTimestamp(i: Int): java.sql.Timestamp = apply(i).asInstanceOf[java.sql.Timestamp] + def getTimestamp(i: Int): java.sql.Timestamp = getAs[java.sql.Timestamp](i) /** * Returns the value at position i of array type as a Scala Seq. * * @throws ClassCastException when data type does not match. */ - def getSeq[T](i: Int): Seq[T] = apply(i).asInstanceOf[Seq[T]] + def getSeq[T](i: Int): Seq[T] = getAs[Seq[T]](i) /** * Returns the value at position i of array type as [[java.util.List]]. @@ -288,7 +288,7 @@ trait Row extends Serializable { * * @throws ClassCastException when data type does not match. */ - def getMap[K, V](i: Int): scala.collection.Map[K, V] = apply(i).asInstanceOf[Map[K, V]] + def getMap[K, V](i: Int): scala.collection.Map[K, V] = getAs[Map[K, V]](i) /** * Returns the value at position i of array type as a [[java.util.Map]]. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala index d7b537a9fe3b..23fa6b535078 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala @@ -25,9 +25,31 @@ import org.apache.spark.sql.catalyst.expressions._ * internal types. */ abstract class InternalRow extends Row { + + // default implementation for codegen (for a Row which does not have those types) + override def getBoolean(i: Int): Boolean = throw new UnsupportedOperationException + override def getByte(i: Int): Byte = throw new UnsupportedOperationException + override def getShort(i: Int): Short = throw new UnsupportedOperationException + override def getInt(i: Int): Int = throw new UnsupportedOperationException + override def getLong(i: Int): Long = throw new UnsupportedOperationException + override def getFloat(i: Int): Float = throw new UnsupportedOperationException + override def getDouble(i: Int): Double = throw new UnsupportedOperationException + override def getString(i: Int): String = throw new UnsupportedOperationException + // A default implementation to change the return type override def copy(): InternalRow = this + def toSeq(): Seq[Any] = { + val n = length + val values = new Array[Any](n) + var i = 0 + while (i < n) { + values(i) = apply(i) + i += 1 + } + values + } + override def equals(o: Any): Boolean = { if (!o.isInstanceOf[Row]) { return false diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala index 05aab3455998..46c88ae28718 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala @@ -241,31 +241,8 @@ final class SpecificMutableRow(val values: Array[MutableValue]) extends MutableR } } - override def setString(ordinal: Int, value: String): Unit = - update(ordinal, UTF8String.fromString(value)) - override def getString(ordinal: Int): String = apply(ordinal).toString - override def setInt(ordinal: Int, value: Int): Unit = { - val currentValue = values(ordinal).asInstanceOf[MutableInt] - currentValue.isNull = false - currentValue.value = value - } - - override def getInt(i: Int): Int = { - values(i).asInstanceOf[MutableInt].value - } - - override def setFloat(ordinal: Int, value: Float): Unit = { - val currentValue = values(ordinal).asInstanceOf[MutableFloat] - currentValue.isNull = false - currentValue.value = value - } - - override def getFloat(i: Int): Float = { - values(i).asInstanceOf[MutableFloat].value - } - override def setBoolean(ordinal: Int, value: Boolean): Unit = { val currentValue = values(ordinal).asInstanceOf[MutableBoolean] currentValue.isNull = false @@ -276,14 +253,15 @@ final class SpecificMutableRow(val values: Array[MutableValue]) extends MutableR values(i).asInstanceOf[MutableBoolean].value } - override def setDouble(ordinal: Int, value: Double): Unit = { - val currentValue = values(ordinal).asInstanceOf[MutableDouble] + + override def setByte(ordinal: Int, value: Byte): Unit = { + val currentValue = values(ordinal).asInstanceOf[MutableByte] currentValue.isNull = false currentValue.value = value } - override def getDouble(i: Int): Double = { - values(i).asInstanceOf[MutableDouble].value + override def getByte(i: Int): Byte = { + values(i).asInstanceOf[MutableByte].value } override def setShort(ordinal: Int, value: Short): Unit = { @@ -296,6 +274,16 @@ final class SpecificMutableRow(val values: Array[MutableValue]) extends MutableR values(i).asInstanceOf[MutableShort].value } + override def setInt(ordinal: Int, value: Int): Unit = { + val currentValue = values(ordinal).asInstanceOf[MutableInt] + currentValue.isNull = false + currentValue.value = value + } + + override def getInt(i: Int): Int = { + values(i).asInstanceOf[MutableInt].value + } + override def setLong(ordinal: Int, value: Long): Unit = { val currentValue = values(ordinal).asInstanceOf[MutableLong] currentValue.isNull = false @@ -306,17 +294,23 @@ final class SpecificMutableRow(val values: Array[MutableValue]) extends MutableR values(i).asInstanceOf[MutableLong].value } - override def setByte(ordinal: Int, value: Byte): Unit = { - val currentValue = values(ordinal).asInstanceOf[MutableByte] + override def setFloat(ordinal: Int, value: Float): Unit = { + val currentValue = values(ordinal).asInstanceOf[MutableFloat] currentValue.isNull = false currentValue.value = value } - override def getByte(i: Int): Byte = { - values(i).asInstanceOf[MutableByte].value + override def getFloat(i: Int): Float = { + values(i).asInstanceOf[MutableFloat].value + } + + override def setDouble(ordinal: Int, value: Double): Unit = { + val currentValue = values(ordinal).asInstanceOf[MutableDouble] + currentValue.isNull = false + currentValue.value = value } - override def getAs[T](i: Int): T = { - values(i).boxed.asInstanceOf[T] + override def getDouble(i: Int): Double = { + values(i).asInstanceOf[MutableDouble].value } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala index e75e82d38054..301a7ec1269f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.expressions.codegen import org.apache.spark.sql.catalyst.expressions._ // MutableProjection is not accessible in Java -abstract class BaseMutableProjection extends MutableProjection {} +abstract class BaseMutableProjection extends MutableProjection /** * Generates byte code that produces a [[MutableRow]] object that can update itself based on a new @@ -46,12 +46,14 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], () => Mu mutableRow.${ctx.setColumn(e.dataType, i, evaluationCode.primitive)}; """ }.mkString("\n") + + val mutableProjectionClass = classOf[BaseMutableProjection].getName val code = s""" public Object generate($exprType[] expr) { return new SpecificProjection(expr); } - class SpecificProjection extends ${classOf[BaseMutableProjection].getName} { + class SpecificProjection extends $mutableProjectionClass { private $exprType[] expressions = null; private $mutableRowType mutableRow = null; @@ -61,14 +63,14 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], () => Mu mutableRow = new $genericMutableRowType(${expressions.size}); } - public ${classOf[BaseMutableProjection].getName} target($mutableRowType row) { + public $mutableProjectionClass target($mutableRowType row) { mutableRow = row; return this; } /* Provide immutable access to the last projected row. */ public InternalRow currentValue() { - return (InternalRow) mutableRow; + return mutableRow; } public Object apply(Object _i) { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala index 624e1cf4e201..61d7b60e56c0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala @@ -17,7 +17,6 @@ package org.apache.spark.sql.catalyst.expressions.codegen -import org.apache.spark.sql.BaseMutableRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.types._ @@ -33,7 +32,6 @@ abstract class BaseProject extends Projection {} * primitive values. */ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] { - import scala.reflect.runtime.universe._ protected def canonicalize(in: Seq[Expression]): Seq[Expression] = in.map(ExpressionCanonicalizer.execute) @@ -154,7 +152,7 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] { return new SpecificProjection(expr); } - class SpecificProjection extends ${typeOf[BaseProject]} { + class SpecificProjection extends ${classOf[BaseProject].getName} { private $exprType[] expressions = null; public SpecificProjection($exprType[] expr) { @@ -167,7 +165,7 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] { } } - final class SpecificRow extends ${typeOf[BaseMutableRow]} { + final class SpecificRow extends $mutableRowType { $columns @@ -175,12 +173,12 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] { $initColumns } - public int size() { return ${expressions.length};} + public int length() { return ${expressions.length};} protected boolean[] nullBits = new boolean[${expressions.length}]; public void setNullAt(int i) { nullBits[i] = true; } public boolean isNullAt(int i) { return nullBits[i]; } - public Object get(int i) { + public Object apply(int i) { if (isNullAt(i)) return null; switch (i) { $getCases diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala index 0d4c9ace5e12..268242e0a07e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.catalyst.expressions +import org.apache.spark.sql.Row import org.apache.spark.sql.types.{DataType, StructType, AtomicType} import org.apache.spark.unsafe.types.UTF8String @@ -24,51 +25,34 @@ import org.apache.spark.unsafe.types.UTF8String * An extended interface to [[InternalRow]] that allows the values for each column to be updated. * Setting a value through a primitive function implicitly marks that column as not null. */ -trait MutableRow extends InternalRow { +abstract class MutableRow extends InternalRow { def setNullAt(i: Int): Unit - def update(ordinal: Int, value: Any) + def update(ordinal: Int, value: Any): Unit - def setInt(ordinal: Int, value: Int) - def setLong(ordinal: Int, value: Long) - def setDouble(ordinal: Int, value: Double) - def setBoolean(ordinal: Int, value: Boolean) - def setShort(ordinal: Int, value: Short) - def setByte(ordinal: Int, value: Byte) - def setFloat(ordinal: Int, value: Float) - def setString(ordinal: Int, value: String) -} + // default implementation for codegen (for a Row which does not have those types) + def setBoolean(ordinal: Int, value: Boolean): Unit = throw new UnsupportedOperationException + def setByte(ordinal: Int, value: Byte): Unit = throw new UnsupportedOperationException + def setShort(ordinal: Int, value: Short): Unit = throw new UnsupportedOperationException + def setInt(ordinal: Int, value: Int): Unit = throw new UnsupportedOperationException + def setLong(ordinal: Int, value: Long): Unit = throw new UnsupportedOperationException + def setFloat(ordinal: Int, value: Float): Unit = throw new UnsupportedOperationException + def setDouble(ordinal: Int, value: Double): Unit = throw new UnsupportedOperationException -/** - * A row with no data. Calling any methods will result in an error. Can be used as a placeholder. - */ -object EmptyRow extends InternalRow { - override def apply(i: Int): Any = throw new UnsupportedOperationException - override def toSeq: Seq[Any] = Seq.empty - override def length: Int = 0 - override def isNullAt(i: Int): Boolean = throw new UnsupportedOperationException - override def getInt(i: Int): Int = throw new UnsupportedOperationException - override def getLong(i: Int): Long = throw new UnsupportedOperationException - override def getDouble(i: Int): Double = throw new UnsupportedOperationException - override def getFloat(i: Int): Float = throw new UnsupportedOperationException - override def getBoolean(i: Int): Boolean = throw new UnsupportedOperationException - override def getShort(i: Int): Short = throw new UnsupportedOperationException - override def getByte(i: Int): Byte = throw new UnsupportedOperationException - override def getString(i: Int): String = throw new UnsupportedOperationException - override def getAs[T](i: Int): T = throw new UnsupportedOperationException - override def copy(): InternalRow = this + def setString(ordinal: Int, value: String): Unit = { + update(ordinal, UTF8String.fromString(value)) + } } /** - * A row implementation that uses an array of objects as the underlying storage. Note that, while - * the array is not copied, and thus could technically be mutated after creation, this is not - * allowed. + * A general row implementation that uses an array of objects as the underlying storage. + * Note that, while the array is not copied, and thus could technically be mutated after + * creation, this is not allowed. */ -class GenericRow(protected[sql] val values: Array[Any]) extends InternalRow { - /** No-arg constructor for serialization. */ - protected def this() = this(null) +trait ArrayBackedRow { + self: Row => - def this(size: Int) = this(new Array[Any](size)) + protected val values: Array[Any] override def toSeq: Seq[Any] = values.toSeq @@ -78,40 +62,71 @@ class GenericRow(protected[sql] val values: Array[Any]) extends InternalRow { override def isNullAt(i: Int): Boolean = values(i) == null + override def getBoolean(i: Int): Boolean = { + if (values(i) == null) sys.error("Failed to check null bit for primitive boolean value.") + getAs[Boolean](i) + } + + override def getByte(i: Int): Byte = { + if (values(i) == null) sys.error("Failed to check null bit for primitive byte value.") + getAs[Byte](i) + } + + override def getShort(i: Int): Short = { + if (values(i) == null) sys.error("Failed to check null bit for primitive short value.") + getAs[Short](i) + } + override def getInt(i: Int): Int = { if (values(i) == null) sys.error("Failed to check null bit for primitive int value.") - values(i).asInstanceOf[Int] + getAs[Int](i) } override def getLong(i: Int): Long = { if (values(i) == null) sys.error("Failed to check null bit for primitive long value.") - values(i).asInstanceOf[Long] - } - - override def getDouble(i: Int): Double = { - if (values(i) == null) sys.error("Failed to check null bit for primitive double value.") - values(i).asInstanceOf[Double] + getAs[Long](i) } override def getFloat(i: Int): Float = { if (values(i) == null) sys.error("Failed to check null bit for primitive float value.") - values(i).asInstanceOf[Float] + getAs[Float](i) } - override def getBoolean(i: Int): Boolean = { - if (values(i) == null) sys.error("Failed to check null bit for primitive boolean value.") - values(i).asInstanceOf[Boolean] + override def getDouble(i: Int): Double = { + if (values(i) == null) sys.error("Failed to check null bit for primitive double value.") + getAs[Double](i) } +} - override def getShort(i: Int): Short = { - if (values(i) == null) sys.error("Failed to check null bit for primitive short value.") - values(i).asInstanceOf[Short] - } - override def getByte(i: Int): Byte = { - if (values(i) == null) sys.error("Failed to check null bit for primitive byte value.") - values(i).asInstanceOf[Byte] - } +/** + * A row with no data. Calling any methods will result in an error. Can be used as a placeholder. + */ +object EmptyRow extends InternalRow { + override def apply(i: Int): Any = throw new UnsupportedOperationException + override def toSeq: Seq[Any] = Seq.empty + override def length: Int = 0 + override def isNullAt(i: Int): Boolean = throw new UnsupportedOperationException + override def getBoolean(i: Int): Boolean = throw new UnsupportedOperationException + override def getByte(i: Int): Byte = throw new UnsupportedOperationException + override def getShort(i: Int): Short = throw new UnsupportedOperationException + override def getInt(i: Int): Int = throw new UnsupportedOperationException + override def getLong(i: Int): Long = throw new UnsupportedOperationException + override def getFloat(i: Int): Float = throw new UnsupportedOperationException + override def getDouble(i: Int): Double = throw new UnsupportedOperationException + override def getString(i: Int): String = throw new UnsupportedOperationException +} + +/** + * A row implementation that uses an array of objects as the underlying storage. Note that, while + * the array is not copied, and thus could technically be mutated after creation, this is not + * allowed. + */ +class GenericRow(protected[sql] val values: Array[Any]) extends InternalRow with ArrayBackedRow { + /** No-arg constructor for serialization. */ + protected def this() = this(null) + + def this(size: Int) = this(new Array[Any](size)) override def getString(i: Int): String = { values(i) match { @@ -120,8 +135,6 @@ class GenericRow(protected[sql] val values: Array[Any]) extends InternalRow { case utf8: UTF8String => utf8.toString } } - - override def copy(): InternalRow = this } class GenericRowWithSchema(values: Array[Any], override val schema: StructType) @@ -133,7 +146,8 @@ class GenericRowWithSchema(values: Array[Any], override val schema: StructType) override def fieldIndex(name: String): Int = schema.fieldIndex(name) } -class GenericMutableRow(v: Array[Any]) extends GenericRow(v) with MutableRow { +class GenericMutableRow(protected[sql] val values: Array[Any]) + extends MutableRow with ArrayBackedRow { /** No-arg constructor for serialization. */ protected def this() = this(null) @@ -141,21 +155,25 @@ class GenericMutableRow(v: Array[Any]) extends GenericRow(v) with MutableRow { override def setBoolean(ordinal: Int, value: Boolean): Unit = { values(ordinal) = value } override def setByte(ordinal: Int, value: Byte): Unit = { values(ordinal) = value } - override def setDouble(ordinal: Int, value: Double): Unit = { values(ordinal) = value } - override def setFloat(ordinal: Int, value: Float): Unit = { values(ordinal) = value } + override def setShort(ordinal: Int, value: Short): Unit = { values(ordinal) = value } override def setInt(ordinal: Int, value: Int): Unit = { values(ordinal) = value } override def setLong(ordinal: Int, value: Long): Unit = { values(ordinal) = value } - override def setString(ordinal: Int, value: String): Unit = { - values(ordinal) = UTF8String.fromString(value) - } + override def setFloat(ordinal: Int, value: Float): Unit = { values(ordinal) = value } + override def setDouble(ordinal: Int, value: Double): Unit = { values(ordinal) = value } override def setNullAt(i: Int): Unit = { values(i) = null } - override def setShort(ordinal: Int, value: Short): Unit = { values(ordinal) = value } - override def update(ordinal: Int, value: Any): Unit = { values(ordinal) = value } override def copy(): InternalRow = new GenericRow(values.clone()) + + override def getString(i: Int): String = { + values(i) match { + case null => null + case s: String => s + case utf8: UTF8String => utf8.toString + } + } }