Skip to content

Commit d206860

Browse files
committed
[SPARK-7066][MLlib] VectorAssembler should use NumericType not NativeType.
Author: Reynold Xin <[email protected]>

Closes apache#5642 from rxin/mllib-native-type and squashes the following commits:

e23af5b [Reynold Xin] Remove StringType
7cbb205 [Reynold Xin] [SPARK-7066][MLlib] VectorAssembler should use NumericType and StringType, not NativeType.
1 parent 1b85e08 commit d206860

File tree

2 files changed

+5
-4
lines changed

2 files changed

+5
-4
lines changed

mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala

Lines changed: 3 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -55,7 +55,8 @@ class VectorAssembler extends Transformer with HasInputCols with HasOutputCol {
5555
schema(c).dataType match {
5656
case DoubleType => UnresolvedAttribute(c)
5757
case t if t.isInstanceOf[VectorUDT] => UnresolvedAttribute(c)
58-
case _: NativeType => Alias(Cast(UnresolvedAttribute(c), DoubleType), s"${c}_double_$uid")()
58+
case _: NumericType =>
59+
Alias(Cast(UnresolvedAttribute(c), DoubleType), s"${c}_double_$uid")()
5960
}
6061
}
6162
dataset.select(col("*"), assembleFunc(new Column(CreateStruct(args))).as(map(outputCol)))
@@ -67,7 +68,7 @@ class VectorAssembler extends Transformer with HasInputCols with HasOutputCol {
6768
val outputColName = map(outputCol)
6869
val inputDataTypes = inputColNames.map(name => schema(name).dataType)
6970
inputDataTypes.foreach {
70-
case _: NativeType =>
71+
case _: NumericType =>
7172
case t if t.isInstanceOf[VectorUDT] =>
7273
case other =>
7374
throw new IllegalArgumentException(s"Data type $other is not supported.")

sql/catalyst/src/main/scala/org/apache/spark/sql/types/dataTypes.scala

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -299,7 +299,7 @@ class NullType private() extends DataType {
299299
case object NullType extends NullType
300300

301301

302-
protected[spark] object NativeType {
302+
protected[sql] object NativeType {
303303
val all = Seq(
304304
IntegerType, BooleanType, LongType, DoubleType, FloatType, ShortType, ByteType, StringType)
305305

@@ -327,7 +327,7 @@ protected[sql] object PrimitiveType {
327327
}
328328
}
329329

330-
protected[spark] abstract class NativeType extends DataType {
330+
protected[sql] abstract class NativeType extends DataType {
331331
private[sql] type JvmType
332332
@transient private[sql] val tag: TypeTag[JvmType]
333333
private[sql] val ordering: Ordering[JvmType]

0 commit comments

Comments
 (0)