@@ -154,9 +154,9 @@ private[spark] class MatrixUDT extends UserDefinedType[Matrix] {
row.setByte(0, 0)
row.setInt(1, sm.numRows)
row.setInt(2, sm.numCols)
row.update(3, sm.colPtrs.toSeq)
row.update(4, sm.rowIndices.toSeq)
row.update(5, sm.values.toSeq)
row.update(3, new GenericArrayData(sm.colPtrs.map(_.asInstanceOf[Any])))
row.update(4, new GenericArrayData(sm.rowIndices.map(_.asInstanceOf[Any])))
row.update(5, new GenericArrayData(sm.values.map(_.asInstanceOf[Any])))
row.setBoolean(6, sm.isTransposed)

case dm: DenseMatrix =>
@@ -165,7 +165,7 @@ private[spark] class MatrixUDT extends UserDefinedType[Matrix] {
row.setInt(2, dm.numCols)
row.setNullAt(3)
row.setNullAt(4)
row.update(5, dm.values.toSeq)
row.update(5, new GenericArrayData(dm.values.map(_.asInstanceOf[Any])))
row.setBoolean(6, dm.isTransposed)
}
row
@@ -179,14 +179,12 @@ private[spark] class MatrixUDT extends UserDefinedType[Matrix] {
val tpe = row.getByte(0)
val numRows = row.getInt(1)
val numCols = row.getInt(2)
val values = row.getAs[Seq[Double]](5, ArrayType(DoubleType, containsNull = false)).toArray
val values = row.getArray(5).toArray.map(_.asInstanceOf[Double])
val isTransposed = row.getBoolean(6)
tpe match {
case 0 =>
val colPtrs =
row.getAs[Seq[Int]](3, ArrayType(IntegerType, containsNull = false)).toArray
val rowIndices =
row.getAs[Seq[Int]](4, ArrayType(IntegerType, containsNull = false)).toArray
val colPtrs = row.getArray(3).toArray.map(_.asInstanceOf[Int])
val rowIndices = row.getArray(4).toArray.map(_.asInstanceOf[Int])
new SparseMatrix(numRows, numCols, colPtrs, rowIndices, values, isTransposed)
case 1 =>
new DenseMatrix(numRows, numCols, values, isTransposed)
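A side note on the serialize change above: instead of handing Catalyst a `Seq`, the UDT now boxes the primitive array and wraps it in a `GenericArrayData`. The sketch below illustrates only that boxing step with a hypothetical `SketchArrayData` stand-in; it is not Spark's class, and its methods are assumptions based on what this diff uses (`numElements`, `isNullAt`, `get`, `toArray`).

```scala
// Hypothetical stand-in for Catalyst's GenericArrayData: an Object[]-backed
// container exposing the accessors this diff relies on.
class SketchArrayData(private val array: Array[Any]) {
  def numElements(): Int = array.length
  def isNullAt(ordinal: Int): Boolean = array(ordinal) == null
  def get(ordinal: Int): Any = array(ordinal)
  def toArray(): Array[Any] = array
}

object SerializeSketch {
  def main(args: Array[String]): Unit = {
    val values: Array[Double] = Array(1.0, 0.0, 3.5)
    // map(_.asInstanceOf[Any]) boxes each primitive into an Object slot,
    // mirroring `new GenericArrayData(sm.values.map(_.asInstanceOf[Any]))`.
    val catalystValues = new SketchArrayData(values.map(_.asInstanceOf[Any]))
    println(catalystValues.numElements()) // 3
  }
}
```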
15 changes: 6 additions & 9 deletions mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala
@@ -187,15 +187,15 @@ private[spark] class VectorUDT extends UserDefinedType[Vector] {
val row = new GenericMutableRow(4)
row.setByte(0, 0)
row.setInt(1, size)
row.update(2, indices.toSeq)
row.update(3, values.toSeq)
row.update(2, new GenericArrayData(indices.map(_.asInstanceOf[Any])))
row.update(3, new GenericArrayData(values.map(_.asInstanceOf[Any])))
row
case DenseVector(values) =>
val row = new GenericMutableRow(4)
row.setByte(0, 1)
row.setNullAt(1)
row.setNullAt(2)
row.update(3, values.toSeq)
row.update(3, new GenericArrayData(values.map(_.asInstanceOf[Any])))
row
}
}
@@ -209,14 +209,11 @@ private[spark] class VectorUDT extends UserDefinedType[Vector] {
tpe match {
case 0 =>
val size = row.getInt(1)
val indices =
row.getAs[Seq[Int]](2, ArrayType(IntegerType, containsNull = false)).toArray
val values =
row.getAs[Seq[Double]](3, ArrayType(DoubleType, containsNull = false)).toArray
val indices = row.getArray(2).toArray().map(_.asInstanceOf[Int])
val values = row.getArray(3).toArray().map(_.asInstanceOf[Double])
new SparseVector(size, indices, values)
case 1 =>
val values =
row.getAs[Seq[Double]](3, ArrayType(DoubleType, containsNull = false)).toArray
val values = row.getArray(3).toArray().map(_.asInstanceOf[Double])
new DenseVector(values)
}
}
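The deserialize side of the same change, in miniature: `row.getArray(i)` now returns an array container whose `toArray` yields boxed elements, so each one is cast back individually. A minimal sketch assuming plain `Array[Any]` in place of `ArrayData`:

```scala
object DeserializeSketch {
  def main(args: Array[String]): Unit = {
    // Roughly what row.getArray(i).toArray hands back: boxed elements typed as Any.
    val boxedIndices: Array[Any] = Array(0, 2, 3)
    val boxedValues: Array[Any] = Array(1.0, 0.5, 2.5)
    // Per-element casts recover the primitive arrays, as in the new
    // `row.getArray(3).toArray.map(_.asInstanceOf[Int])` pattern.
    val indices: Array[Int] = boxedIndices.map(_.asInstanceOf[Int])
    val values: Array[Double] = boxedValues.map(_.asInstanceOf[Double])
    println(indices.mkString(",") + " / " + values.mkString(","))
  }
}
```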
@@ -18,6 +18,7 @@
package org.apache.spark.sql.catalyst.expressions;

import org.apache.spark.sql.catalyst.InternalRow;
import org.apache.spark.sql.types.ArrayData;
import org.apache.spark.sql.types.Decimal;
import org.apache.spark.unsafe.types.CalendarInterval;
import org.apache.spark.unsafe.types.UTF8String;
@@ -50,4 +51,5 @@ public interface SpecializedGetters {

InternalRow getStruct(int ordinal, int numFields);

ArrayData getArray(int ordinal);
}
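For orientation, the contract that the new `getArray` method joins can be rendered in Scala roughly as below. This is a toy sketch, not the real interface; `GetterSketch` and `ToyRow` are made-up names, and `Array[Any]` stands in for `ArrayData`.

```scala
// Toy rendering of the SpecializedGetters idea: one typed accessor per kind
// of value, now including a direct accessor for array containers.
trait GetterSketch {
  def isNullAt(ordinal: Int): Boolean
  def getInt(ordinal: Int): Int
  def getArray(ordinal: Int): Array[Any] // ArrayData in Spark
}

// A row backed by Array[Any], showing how getArray slots in beside the others.
class ToyRow(fields: Array[Any]) extends GetterSketch {
  def isNullAt(ordinal: Int): Boolean = fields(ordinal) == null
  def getInt(ordinal: Int): Int = fields(ordinal).asInstanceOf[Int]
  def getArray(ordinal: Int): Array[Any] = fields(ordinal).asInstanceOf[Array[Any]]
}
```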
@@ -55,7 +55,6 @@ object CatalystTypeConverters {

private def isWholePrimitive(dt: DataType): Boolean = dt match {
case dt if isPrimitive(dt) => true
case ArrayType(elementType, _) => isWholePrimitive(elementType)
case MapType(keyType, valueType, _) => isWholePrimitive(keyType) && isWholePrimitive(valueType)
case _ => false
}
@@ -154,39 +153,41 @@ object CatalystTypeConverters {

/** Converter for arrays, sequences, and Java iterables. */
private case class ArrayConverter(
elementType: DataType) extends CatalystTypeConverter[Any, Seq[Any], Seq[Any]] {
elementType: DataType) extends CatalystTypeConverter[Any, Seq[Any], ArrayData] {

private[this] val elementConverter = getConverterForType(elementType)

private[this] val isNoChange = isWholePrimitive(elementType)

override def toCatalystImpl(scalaValue: Any): Seq[Any] = {
override def toCatalystImpl(scalaValue: Any): ArrayData = {
scalaValue match {
case a: Array[_] => a.toSeq.map(elementConverter.toCatalyst)
case s: Seq[_] => s.map(elementConverter.toCatalyst)
case a: Array[_] =>
new GenericArrayData(a.map(elementConverter.toCatalyst))
case s: Seq[_] =>
new GenericArrayData(s.map(elementConverter.toCatalyst).toArray)
case i: JavaIterable[_] =>
val iter = i.iterator
var convertedIterable: List[Any] = List()
val convertedIterable = scala.collection.mutable.ArrayBuffer.empty[Any]
while (iter.hasNext) {
val item = iter.next()
convertedIterable :+= elementConverter.toCatalyst(item)
convertedIterable += elementConverter.toCatalyst(item)
}
convertedIterable
new GenericArrayData(convertedIterable.toArray)
}
}

override def toScala(catalystValue: Seq[Any]): Seq[Any] = {
override def toScala(catalystValue: ArrayData): Seq[Any] = {
if (catalystValue == null) {
null
} else if (isNoChange) {
catalystValue
catalystValue.toArray()
} else {
catalystValue.map(elementConverter.toScala)
catalystValue.toArray().map(elementConverter.toScala)
}
}

override def toScalaImpl(row: InternalRow, column: Int): Seq[Any] =
toScala(row.get(column, ArrayType(elementType)).asInstanceOf[Seq[Any]])
toScala(row.getArray(column))
}

private case class MapConverter(
@@ -402,9 +403,9 @@ object CatalystTypeConverters {
case t: Timestamp => TimestampConverter.toCatalyst(t)
case d: BigDecimal => BigDecimalConverter.toCatalyst(d)
case d: JavaBigDecimal => BigDecimalConverter.toCatalyst(d)
case seq: Seq[Any] => seq.map(convertToCatalyst)
case seq: Seq[Any] => new GenericArrayData(seq.map(convertToCatalyst).toArray)
case r: Row => InternalRow(r.toSeq.map(convertToCatalyst): _*)
case arr: Array[Any] => arr.map(convertToCatalyst)
case arr: Array[Any] => new GenericArrayData(arr.map(convertToCatalyst))
case m: Map[_, _] =>
m.map { case (k, v) => (convertToCatalyst(k), convertToCatalyst(v)) }.toMap
case other => other
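The converter change above can be read as a round trip: Scala collections are materialised into a single boxed array on the way into Catalyst, and materialised back into a `Seq` on the way out. A self-contained sketch of that flow, using `Array[Any]` in place of `GenericArrayData` and made-up helper names:

```scala
import java.lang.{Iterable => JavaIterable}

object ArrayConverterSketch {
  // Catalyst direction: arrays, Seqs and Java iterables all end up as one
  // boxed Array[Any] (GenericArrayData's backing store in the real code).
  def toCatalyst(value: Any): Array[Any] = value match {
    case a: Array[_] => a.map(_.asInstanceOf[Any])
    case s: Seq[_]   => s.map(_.asInstanceOf[Any]).toArray
    case i: JavaIterable[_] =>
      // Same buffer-based loop the converter now uses instead of List :+=.
      val buf = scala.collection.mutable.ArrayBuffer.empty[Any]
      val it = i.iterator()
      while (it.hasNext) buf += it.next()
      buf.toArray
  }

  // Scala direction: materialise the whole array back into a Seq.
  def toScala(catalystValue: Array[Any]): Seq[Any] =
    if (catalystValue == null) null else catalystValue.toSeq

  def main(args: Array[String]): Unit = {
    println(toScala(toCatalyst(Seq(1, 2, 3))).mkString(","))                  // 1,2,3
    println(toScala(toCatalyst(java.util.Arrays.asList(4, 5))).mkString(",")) // 4,5
  }
}
```

Appending with `:+=` rebuilds an immutable `List` on every element, so the switch to `ArrayBuffer` also avoids quadratic behaviour on large Java iterables.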
@@ -76,6 +76,8 @@ abstract class InternalRow extends Serializable with SpecializedGetters {
override def getStruct(ordinal: Int, numFields: Int): InternalRow =
getAs[InternalRow](ordinal, null)

override def getArray(ordinal: Int): ArrayData = getAs(ordinal, null)

override def toString: String = s"[${this.mkString(",")}]"

/**
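`InternalRow` picks up `getArray` through the same generic-getter fallback its other defaults use. A toy parallel with a one-argument `getAs` and made-up class names (the real `getAs` also takes a `DataType`):

```scala
// Toy parallel to InternalRow's defaults: typed getters layered over one
// generic accessor, so getArray is just a cast of the generic value.
abstract class ToyInternalRow {
  def genericGet(ordinal: Int): Any

  private def getAs[T](ordinal: Int): T = genericGet(ordinal).asInstanceOf[T]

  def getInt(ordinal: Int): Int = getAs[Int](ordinal)
  def getArray(ordinal: Int): Array[Any] = getAs[Array[Any]](ordinal) // ArrayData in Spark
}
```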
@@ -65,7 +65,7 @@ case class BoundReference(ordinal: Int, dataType: DataType, nullable: Boolean)

override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
val javaType = ctx.javaType(dataType)
val value = ctx.getColumn("i", dataType, ordinal)
val value = ctx.getValue("i", dataType, ordinal.toString)
s"""
boolean ${ev.isNull} = i.isNullAt($ordinal);
$javaType ${ev.primitive} = ${ev.isNull} ? ${ctx.defaultValue(dataType)} : ($value);
@@ -363,7 +363,21 @@ case class Cast(child: Expression, dataType: DataType)

private[this] def castArray(from: ArrayType, to: ArrayType): Any => Any = {
val elementCast = cast(from.elementType, to.elementType)
buildCast[Seq[Any]](_, _.map(v => if (v == null) null else elementCast(v)))
// TODO: Could be faster?
buildCast[ArrayData](_, array => {
val length = array.numElements()
val values = new Array[Any](length)
var i = 0
while (i < length) {
if (array.isNullAt(i)) {
values(i) = null
} else {
values(i) = elementCast(array.get(i))
}
i += 1
}
new GenericArrayData(values)
})
}

private[this] def castMap(from: MapType, to: MapType): Any => Any = {
@@ -789,37 +803,36 @@ case class Cast(child: Expression, dataType: DataType)
private[this] def castArrayCode(
from: ArrayType, to: ArrayType, ctx: CodeGenContext): CastFunction = {
val elementCast = nullSafeCastFunction(from.elementType, to.elementType, ctx)

val arraySeqClass = classOf[mutable.ArraySeq[Any]].getName
val arrayClass = classOf[GenericArrayData].getName
val fromElementNull = ctx.freshName("feNull")
val fromElementPrim = ctx.freshName("fePrim")
val toElementNull = ctx.freshName("teNull")
val toElementPrim = ctx.freshName("tePrim")
val size = ctx.freshName("n")
val j = ctx.freshName("j")
val result = ctx.freshName("result")
val values = ctx.freshName("values")

(c, evPrim, evNull) =>
s"""
final int $size = $c.size();
final $arraySeqClass<Object> $result = new $arraySeqClass<Object>($size);
final int $size = $c.numElements();
final Object[] $values = new Object[$size];
for (int $j = 0; $j < $size; $j ++) {
if ($c.apply($j) == null) {
$result.update($j, null);
if ($c.isNullAt($j)) {
$values[$j] = null;
} else {
boolean $fromElementNull = false;
${ctx.javaType(from.elementType)} $fromElementPrim =
(${ctx.boxedType(from.elementType)}) $c.apply($j);
${ctx.getValue(c, from.elementType, j)};
${castCode(ctx, fromElementPrim,
fromElementNull, toElementPrim, toElementNull, to.elementType, elementCast)}
if ($toElementNull) {
$result.update($j, null);
$values[$j] = null;
} else {
$result.update($j, $toElementPrim);
$values[$j] = $toElementPrim;
}
}
}
$evPrim = $result;
$evPrim = new $arrayClass($values);
"""
}

@@ -891,7 +904,7 @@ case class Cast(child: Expression, dataType: DataType)
$result.setNullAt($i);
} else {
$fromType $fromFieldPrim =
${ctx.getColumn(tmpRow, from.fields(i).dataType, i)};
${ctx.getValue(tmpRow, from.fields(i).dataType, i.toString)};
${castCode(ctx, fromFieldPrim,
fromFieldNull, toFieldPrim, toFieldNull, to.fields(i).dataType, cast)}
if ($toFieldNull) {
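Both the interpreted and generated `castArray` paths above follow the same shape: walk the array by index, preserve nulls, cast each element, and wrap the result. A standalone sketch of that loop, specialised to a hypothetical Int-to-Long element cast over `Array[Any]`:

```scala
object CastArraySketch {
  // Null-preserving, element-by-element cast, mirroring the new castArray loop.
  def castIntArrayToLong(array: Array[Any]): Array[Any] = {
    val length = array.length              // array.numElements() in Spark
    val values = new Array[Any](length)
    var i = 0
    while (i < length) {
      if (array(i) == null) {              // array.isNullAt(i) in Spark
        values(i) = null
      } else {
        values(i) = array(i).asInstanceOf[Int].toLong
      }
      i += 1
    }
    values                                 // wrapped in GenericArrayData in Spark
  }

  def main(args: Array[String]): Unit = {
    println(castIntArrayToLong(Array[Any](1, null, 3)).mkString(",")) // 1,null,3
  }
}
```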
@@ -100,17 +100,18 @@ class CodeGenContext {
}

/**
* Returns the code to access a column in Row for a given DataType.
* Returns the code to access a value in `SpecializedGetters` for a given DataType.
*/
def getColumn(row: String, dataType: DataType, ordinal: Int): String = {
def getValue(getter: String, dataType: DataType, ordinal: String): String = {
val jt = javaType(dataType)
dataType match {
case _ if isPrimitiveType(jt) => s"$row.get${primitiveTypeName(jt)}($ordinal)"
case StringType => s"$row.getUTF8String($ordinal)"
case BinaryType => s"$row.getBinary($ordinal)"
case CalendarIntervalType => s"$row.getInterval($ordinal)"
case t: StructType => s"$row.getStruct($ordinal, ${t.size})"
case _ => s"($jt)$row.get($ordinal)"
case _ if isPrimitiveType(jt) => s"$getter.get${primitiveTypeName(jt)}($ordinal)"
case StringType => s"$getter.getUTF8String($ordinal)"
case BinaryType => s"$getter.getBinary($ordinal)"
case CalendarIntervalType => s"$getter.getInterval($ordinal)"
case t: StructType => s"$getter.getStruct($ordinal, ${t.size})"
case a: ArrayType => s"$getter.getArray($ordinal)"
case _ => s"($jt)$getter.get($ordinal)" // todo: remove generic getter.
}
}

@@ -152,8 +153,8 @@ class CodeGenContext {
case StringType => "UTF8String"
case CalendarIntervalType => "CalendarInterval"
case _: StructType => "InternalRow"
case _: ArrayType => s"scala.collection.Seq"
case _: MapType => s"scala.collection.Map"
case _: ArrayType => "ArrayData"
case _: MapType => "scala.collection.Map"
case dt: OpenHashSetUDT if dt.elementType == IntegerType => classOf[IntegerHashSet].getName
case dt: OpenHashSetUDT if dt.elementType == LongType => classOf[LongHashSet].getName
case _ => "Object"
@@ -214,7 +215,9 @@ class CodeGenContext {
case dt: DataType if isPrimitiveType(dt) => s"($c1 > $c2 ? 1 : $c1 < $c2 ? -1 : 0)"
case BinaryType => s"org.apache.spark.sql.catalyst.util.TypeUtils.compareBinary($c1, $c2)"
case NullType => "0"
case other => s"$c1.compare($c2)"
case other if other.isInstanceOf[AtomicType] => s"$c1.compare($c2)"
case _ => throw new IllegalArgumentException(
"cannot generate compare code for un-comparable type")
}

/**
@@ -293,7 +296,8 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin
classOf[UnsafeRow].getName,
classOf[UTF8String].getName,
classOf[Decimal].getName,
classOf[CalendarInterval].getName
classOf[CalendarInterval].getName,
classOf[ArrayData].getName
))
evaluator.setExtendedClass(classOf[GeneratedClass])
try {
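The `getColumn` to `getValue` rename generalises code generation from rows to anything implementing the getters, and takes the ordinal as a string so a generated loop counter like `$j` can be spliced in rather than a literal index. A toy version with type names as strings instead of `DataType` instances (all names here are made up):

```scala
object GetValueSketch {
  // Toy analogue of CodeGenContext.getValue: build the Java source snippet
  // that reads one value from a getter expression.
  def getValue(getter: String, dataType: String, ordinal: String): String =
    dataType match {
      case "int"    => s"$getter.getInt($ordinal)"
      case "double" => s"$getter.getDouble($ordinal)"
      case "array"  => s"$getter.getArray($ordinal)"     // the newly added case
      case _        => s"(Object) $getter.get($ordinal)" // generic fallback
    }

  def main(args: Array[String]): Unit = {
    println(getValue("i", "array", "0"))  // i.getArray(0)
    println(getValue("c", "double", "j")) // c.getDouble(j) -- ordinal is a loop variable
  }
}
```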
@@ -153,14 +153,14 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
val nestedStructEv = GeneratedExpressionCode(
code = "",
isNull = s"${input.primitive}.isNullAt($i)",
primitive = s"${ctx.getColumn(input.primitive, dt, i)}"
primitive = s"${ctx.getValue(input.primitive, dt, i.toString)}"
)
createCodeForStruct(ctx, nestedStructEv, st)
case _ =>
GeneratedExpressionCode(
code = "",
isNull = s"${input.primitive}.isNullAt($i)",
primitive = s"${ctx.getColumn(input.primitive, dt, i)}"
primitive = s"${ctx.getValue(input.primitive, dt, i.toString)}"
)
}
}
@@ -27,11 +27,15 @@ case class Size(child: Expression) extends UnaryExpression with ExpectsInputType
override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(ArrayType, MapType))

override def nullSafeEval(value: Any): Int = child.dataType match {
case ArrayType(_, _) => value.asInstanceOf[Seq[Any]].size
case MapType(_, _, _) => value.asInstanceOf[Map[Any, Any]].size
case _: ArrayType => value.asInstanceOf[ArrayData].numElements()
case _: MapType => value.asInstanceOf[Map[Any, Any]].size
}

override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
nullSafeCodeGen(ctx, ev, c => s"${ev.primitive} = ($c).size();")
val sizeCall = child.dataType match {
case _: ArrayType => "numElements()"
case _: MapType => "size()"
}
nullSafeCodeGen(ctx, ev, c => s"${ev.primitive} = ($c).$sizeCall;")
}
}
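Finally, `Size` now distinguishes the two container kinds at both eval and codegen time: arrays report `numElements()`, maps keep reporting `size`. A small sketch using `Array[_]` in place of `ArrayData`:

```scala
object SizeSketch {
  // Miniature of the new Size semantics: arrays and maps are sized differently.
  def size(value: Any): Int = value match {
    case a: Array[_]  => a.length // ArrayData.numElements() in Spark
    case m: Map[_, _] => m.size
  }

  def main(args: Array[String]): Unit = {
    println(size(Array[Any](1, 2, 3)))     // 3
    println(size(Map("a" -> 1, "b" -> 2))) // 2
  }
}
```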