From fa5e3f94801bbdeb5f2d303a79fb0ebd89733436 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?C=C3=A9dric=20Pelvet?= Date: Thu, 17 Aug 2017 23:48:58 +0200 Subject: [PATCH 1/3] Correct validateAndTransformSchema --- .../org/apache/spark/ml/clustering/GaussianMixture.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala index 5259ee419445f..90fd1268e9568 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala @@ -64,8 +64,8 @@ private[clustering] trait GaussianMixtureParams extends Params with HasMaxIter w */ protected def validateAndTransformSchema(schema: StructType): StructType = { SchemaUtils.checkColumnType(schema, $(featuresCol), new VectorUDT) - SchemaUtils.appendColumn(schema, $(predictionCol), IntegerType) - SchemaUtils.appendColumn(schema, $(probabilityCol), new VectorUDT) + val schema_temp = SchemaUtils.appendColumn(schema, $(predictionCol), IntegerType) + SchemaUtils.appendColumn(schema_temp, $(probabilityCol), new VectorUDT) } } From b2df776c5f48bfb1c3e5e0cda82b901405f9c6c5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?C=C3=A9dric=20Pelvet?= Date: Thu, 17 Aug 2017 23:48:58 +0200 Subject: [PATCH 2/3] Correct validateAndTransformSchema --- .../org/apache/spark/ml/clustering/GaussianMixture.scala | 4 ++-- .../spark/ml/regression/AFTSurvivalRegression.scala | 8 +++++--- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala index 5259ee419445f..90fd1268e9568 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala @@ -64,8 +64,8 @@ private[clustering] trait GaussianMixtureParams extends Params with HasMaxIter w */ protected def validateAndTransformSchema(schema: StructType): StructType = { SchemaUtils.checkColumnType(schema, $(featuresCol), new VectorUDT) - SchemaUtils.appendColumn(schema, $(predictionCol), IntegerType) - SchemaUtils.appendColumn(schema, $(probabilityCol), new VectorUDT) + val schema_temp = SchemaUtils.appendColumn(schema, $(predictionCol), IntegerType) + SchemaUtils.appendColumn(schema_temp, $(probabilityCol), new VectorUDT) } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala index 0891994530f88..56054cea907cd 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala @@ -109,10 +109,12 @@ private[regression] trait AFTSurvivalRegressionParams extends Params SchemaUtils.checkNumericType(schema, $(censorCol)) SchemaUtils.checkNumericType(schema, $(labelCol)) } - if (hasQuantilesCol) { + + val temp_schema = if (hasQuantilesCol) { SchemaUtils.appendColumn(schema, $(quantilesCol), new VectorUDT) - } - SchemaUtils.appendColumn(schema, $(predictionCol), DoubleType) + } else schema + + SchemaUtils.appendColumn(temp_schema, $(predictionCol), DoubleType) } } From d14c21e14b37501f75cf39244da7d93fa66b47fc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?C=C3=A9dric=20Pelvet?= Date: Sat, 19 Aug 2017 12:31:18 +0200 Subject: [PATCH 3/3] Renamed schema_temp with more descriptive camelCase variable names. --- .../org/apache/spark/ml/clustering/GaussianMixture.scala | 4 ++-- .../apache/spark/ml/regression/AFTSurvivalRegression.scala | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala index 90fd1268e9568..f19ad7a5a6938 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala @@ -64,8 +64,8 @@ private[clustering] trait GaussianMixtureParams extends Params with HasMaxIter w */ protected def validateAndTransformSchema(schema: StructType): StructType = { SchemaUtils.checkColumnType(schema, $(featuresCol), new VectorUDT) - val schema_temp = SchemaUtils.appendColumn(schema, $(predictionCol), IntegerType) - SchemaUtils.appendColumn(schema_temp, $(probabilityCol), new VectorUDT) + val schemaWithPredictionCol = SchemaUtils.appendColumn(schema, $(predictionCol), IntegerType) + SchemaUtils.appendColumn(schemaWithPredictionCol, $(probabilityCol), new VectorUDT) } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala index 56054cea907cd..16821f317760e 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala @@ -110,11 +110,11 @@ private[regression] trait AFTSurvivalRegressionParams extends Params SchemaUtils.checkNumericType(schema, $(labelCol)) } - val temp_schema = if (hasQuantilesCol) { + val schemaWithQuantilesCol = if (hasQuantilesCol) { SchemaUtils.appendColumn(schema, $(quantilesCol), new VectorUDT) } else schema - SchemaUtils.appendColumn(temp_schema, $(predictionCol), DoubleType) + SchemaUtils.appendColumn(schemaWithQuantilesCol, $(predictionCol), DoubleType) } }