From 7c3d60f5cc1a8d075a04c59ca4785d9e80240374 Mon Sep 17 00:00:00 2001 From: Grzegorz Chilkiewicz Date: Wed, 13 Jan 2016 14:27:35 +0100 Subject: [PATCH 1/4] [SPARK-12711] ML StopWordsRemover does not protect itself from column name duplication --- .../org/apache/spark/ml/feature/StopWordsRemover.scala | 4 +--- .../main/scala/org/apache/spark/ml/util/SchemaUtils.scala | 8 +++----- 2 files changed, 4 insertions(+), 8 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala index b93c9ed382bd..e53ef300f644 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala @@ -149,9 +149,7 @@ class StopWordsRemover(override val uid: String) val inputType = schema($(inputCol)).dataType require(inputType.sameType(ArrayType(StringType)), s"Input type must be ArrayType(StringType) but got $inputType.") - val outputFields = schema.fields :+ - StructField($(outputCol), inputType, schema($(inputCol)).nullable) - StructType(outputFields) + SchemaUtils.appendColumn(schema, $(outputCol), inputType, schema($(inputCol)).nullable) } override def copy(extra: ParamMap): StopWordsRemover = defaultCopy(extra) diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/SchemaUtils.scala b/mllib/src/main/scala/org/apache/spark/ml/util/SchemaUtils.scala index 76f651488aef..7decbbd0b96b 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/util/SchemaUtils.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/util/SchemaUtils.scala @@ -54,12 +54,10 @@ private[spark] object SchemaUtils { def appendColumn( schema: StructType, colName: String, - dataType: DataType): StructType = { + dataType: DataType, + nullable: Boolean = false): StructType = { if (colName.isEmpty) return schema - val fieldNames = schema.fieldNames - require(!fieldNames.contains(colName), s"Column $colName already exists.") - val outputFields = schema.fields :+ StructField(colName, dataType, nullable = false) - StructType(outputFields) + appendColumn(schema, StructField(colName, dataType, nullable)) } /** From 1474b6fe51007f912646f41b9d30af060edb8f3b Mon Sep 17 00:00:00 2001 From: Grzegorz Chilkiewicz Date: Wed, 13 Jan 2016 14:27:35 +0100 Subject: [PATCH 2/4] [SPARK-12711] ML StopWordsRemover does not protect itself from column name duplication --- .../spark/ml/feature/StopWordsRemover.scala | 4 +--- .../org/apache/spark/ml/util/SchemaUtils.scala | 8 +++----- .../ml/feature/StopWordsRemoverSuite.scala | 18 ++++++++++++++++++ 3 files changed, 22 insertions(+), 8 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala index b93c9ed382bd..e53ef300f644 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala @@ -149,9 +149,7 @@ class StopWordsRemover(override val uid: String) val inputType = schema($(inputCol)).dataType require(inputType.sameType(ArrayType(StringType)), s"Input type must be ArrayType(StringType) but got $inputType.") - val outputFields = schema.fields :+ - StructField($(outputCol), inputType, schema($(inputCol)).nullable) - StructType(outputFields) + SchemaUtils.appendColumn(schema, $(outputCol), inputType, schema($(inputCol)).nullable) } override def copy(extra: ParamMap): StopWordsRemover = defaultCopy(extra) diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/SchemaUtils.scala b/mllib/src/main/scala/org/apache/spark/ml/util/SchemaUtils.scala index 76f651488aef..7decbbd0b96b 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/util/SchemaUtils.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/util/SchemaUtils.scala @@ -54,12 +54,10 @@ private[spark] object SchemaUtils { def appendColumn( schema: StructType, colName: String, - dataType: DataType): StructType = { + dataType: DataType, + nullable: Boolean = false): StructType = { if (colName.isEmpty) return schema - val fieldNames = schema.fieldNames - require(!fieldNames.contains(colName), s"Column $colName already exists.") - val outputFields = schema.fields :+ StructField(colName, dataType, nullable = false) - StructType(outputFields) + appendColumn(schema, StructField(colName, dataType, nullable)) } /** diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/StopWordsRemoverSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/StopWordsRemoverSuite.scala index fb217e0c1de9..2f9589033954 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/StopWordsRemoverSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/StopWordsRemoverSuite.scala @@ -89,4 +89,22 @@ class StopWordsRemoverSuite .setCaseSensitive(true) testDefaultReadWrite(t) } + + test("StopWordsRemover output column already exists") { + val outpuCol = "expected" + val remover = new StopWordsRemover() + .setInputCol("raw") + .setOutputCol(outpuCol) + .setCaseSensitive(true) + val dataSet = sqlContext.createDataFrame(Seq( + (Seq("A"), Seq("A")), + (Seq("The", "the"), Seq("The")) + )).toDF("raw", outpuCol) + + val thrown = intercept[IllegalArgumentException] { + testStopWordsRemover(remover, dataSet) + } + assert(thrown.getClass === classOf[IllegalArgumentException]) + assert(thrown.getMessage == s"requirement failed: Column ${outpuCol} already exists.") + } } From 37af3910430203793eddbc70988480ba33bfeef7 Mon Sep 17 00:00:00 2001 From: Grzegorz Chilkiewicz Date: Thu, 14 Jan 2016 09:18:08 +0100 Subject: [PATCH 3/4] [SPARK-12711][ML] ML StopWordsRemover does not protect itself from column name duplication Fixes problem and verifies fix by test suite. Also - adds optional parameter nullable (Boolean) to: SchemaUtils.appendColumn and deduplicates SchemaUtils.appendColumn functions. --- .../spark/ml/feature/StopWordsRemoverSuite.scala | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/StopWordsRemoverSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/StopWordsRemoverSuite.scala index 2f9589033954..a41ede9f3def 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/StopWordsRemoverSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/StopWordsRemoverSuite.scala @@ -91,20 +91,18 @@ class StopWordsRemoverSuite } test("StopWordsRemover output column already exists") { - val outpuCol = "expected" + val outputCol = "expected" val remover = new StopWordsRemover() .setInputCol("raw") - .setOutputCol(outpuCol) - .setCaseSensitive(true) + .setOutputCol(outputCol) val dataSet = sqlContext.createDataFrame(Seq( - (Seq("A"), Seq("A")), - (Seq("The", "the"), Seq("The")) - )).toDF("raw", outpuCol) + (Seq("A"), Seq()), + (Seq("The", "the"), Seq()) + )).toDF("raw", outputCol) val thrown = intercept[IllegalArgumentException] { testStopWordsRemover(remover, dataSet) } - assert(thrown.getClass === classOf[IllegalArgumentException]) - assert(thrown.getMessage == s"requirement failed: Column ${outpuCol} already exists.") + assert(thrown.getMessage == s"requirement failed: Column $outputCol already exists.") } } From 49fd362d5f9426c74427fa089de750c3c4e67a2d Mon Sep 17 00:00:00 2001 From: Grzegorz Chilkiewicz Date: Wed, 27 Jan 2016 16:24:21 +0100 Subject: [PATCH 4/4] Fix empty DataFrame column error --- .../org/apache/spark/ml/feature/StopWordsRemoverSuite.scala | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/StopWordsRemoverSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/StopWordsRemoverSuite.scala index a41ede9f3def..a5b24c18565b 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/StopWordsRemoverSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/StopWordsRemoverSuite.scala @@ -96,8 +96,7 @@ class StopWordsRemoverSuite .setInputCol("raw") .setOutputCol(outputCol) val dataSet = sqlContext.createDataFrame(Seq( - (Seq("A"), Seq()), - (Seq("The", "the"), Seq()) + (Seq("The", "the", "swift"), Seq("swift")) )).toDF("raw", outputCol) val thrown = intercept[IllegalArgumentException] {