From 2de848f6e38551cceb59111e9a8e0663b9a8d729 Mon Sep 17 00:00:00 2001 From: Earthson Lu Date: Wed, 17 Feb 2016 17:43:01 +0800 Subject: [PATCH 1/2] backport from master --- .../spark/ml/feature/CountVectorizer.scala | 3 ++- .../org/apache/spark/ml/util/SchemaUtils.scala | 17 +++++++++++++++++ 2 files changed, 19 insertions(+), 1 deletion(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala index 49028e4b8506..2f532d11ac88 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala @@ -68,7 +68,8 @@ private[feature] trait CountVectorizerParams extends Params with HasInputCol wit /** Validates and transforms the input schema. */ protected def validateAndTransformSchema(schema: StructType): StructType = { - SchemaUtils.checkColumnType(schema, $(inputCol), new ArrayType(StringType, true)) + val typeCandidates = List(ArrayType(StringType, true), ArrayType(StringType, false)) + SchemaUtils.checkColumnTypes(schema, $(inputCol), typeCandidates) SchemaUtils.appendColumn(schema, $(outputCol), new VectorUDT) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/SchemaUtils.scala b/mllib/src/main/scala/org/apache/spark/ml/util/SchemaUtils.scala index 76f651488aef..ab4e0456b2c8 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/util/SchemaUtils.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/util/SchemaUtils.scala @@ -43,6 +43,23 @@ private[spark] object SchemaUtils { s"Column $colName must be of type $dataType but was actually $actualDataType.$message") } + /** + * Check whether the given schema contains a column of one of the required data types.
+ * @param colName column name + * @param dataTypes required column data types + */ + def checkColumnTypes( + schema: StructType, + colName: String, + dataTypes: Seq[DataType], + msg: String = ""): Unit = { + val actualDataType = schema(colName).dataType + val message = if (msg != null && msg.trim.length > 0) " " + msg else "" + require(dataTypes.exists(actualDataType.equals), + s"Column $colName must be of type equal to one of the following types: " + + s"${dataTypes.mkString("[", ", ", "]")} but was actually of type $actualDataType.$message") + } + /** * Appends a new column to the input schema. This fails if the given output column already exists. * @param schema input schema From 30655f5d06f3ea8a137c4950ac9668c3966afda5 Mon Sep 17 00:00:00 2001 From: Earthson Lu Date: Fri, 19 Feb 2016 10:50:41 +0800 Subject: [PATCH 2/2] fix style problem --- .../main/scala/org/apache/spark/ml/util/SchemaUtils.scala | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/SchemaUtils.scala b/mllib/src/main/scala/org/apache/spark/ml/util/SchemaUtils.scala index ab4e0456b2c8..e71dd9eee03e 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/util/SchemaUtils.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/util/SchemaUtils.scala @@ -49,10 +49,10 @@ private[spark] object SchemaUtils { * @param dataTypes required column data types */ def checkColumnTypes( - schema: StructType, - colName: String, - dataTypes: Seq[DataType], - msg: String = ""): Unit = { + schema: StructType, + colName: String, + dataTypes: Seq[DataType], + msg: String = ""): Unit = { val actualDataType = schema(colName).dataType val message = if (msg != null && msg.trim.length > 0) " " + msg else "" require(dataTypes.exists(actualDataType.equals),