From c42cc51e25bf8b88aac08bec3c1f636b5d1a8a3a Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Thu, 16 Jun 2016 20:09:19 +0900 Subject: [PATCH 1/3] Backport correctly merging AFTAggregator to branch 1.6 --- .../ml/regression/AFTSurvivalRegression.scala | 2 +- .../regression/AFTSurvivalRegressionSuite.scala | 17 +++++++++++++++++ 2 files changed, 18 insertions(+), 1 deletion(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala index aedfb48058dc5..cc1d19e4a81ff 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala @@ -496,7 +496,7 @@ private class AFTAggregator(parameters: BDV[Double], fitIntercept: Boolean) * @return This AFTAggregator object. */ def merge(other: AFTAggregator): this.type = { - if (totalCnt != 0) { + if (other.count != 0) { totalCnt += other.totalCnt lossSum += other.lossSum diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala index d718ef63b531a..339a26934e0ff 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala @@ -346,6 +346,23 @@ class AFTSurvivalRegressionSuite testEstimatorAndModelReadWrite(aft, datasetMultivariate, AFTSurvivalRegressionSuite.allParamSettings, checkModelData) } + + test("SPARK-15892: Incorrectly merged AFTAggregator with zero total count") { + // This `dataset` will contain an empty partition because it has five rows but + // the parallelism is bigger than that. 
Because the issue was about `AFTAggregator`s + // being merged incorrectly when it has an empty partition, running the codes below + // should not throw an exception. + val points = sc.parallelize(Seq( + AFTPoint(Vectors.dense(1.560, -0.605), 1.218, 1.0), + AFTPoint(Vectors.dense(0.346, 2.158), 2.949, 0.0), + AFTPoint(Vectors.dense(1.380, 0.231), 3.627, 0.0), + AFTPoint(Vectors.dense(0.520, 1.151), 0.273, 1.0), + AFTPoint(Vectors.dense(0.795, -0.226), 4.199, 0.0)), numSlices = 6) + val dataset = sqlContext.createDataFrame(points) + val trainer = new AFTSurvivalRegression() + val model = trainer.fit(dataset) + assert(model.scale != 1) + } } object AFTSurvivalRegressionSuite { From b0203fcaeeb865bfe21bfff0475fa1be50c4f5ba Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Fri, 17 Jun 2016 10:13:56 +0900 Subject: [PATCH 2/3] Update comments --- .../spark/ml/regression/AFTSurvivalRegressionSuite.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala index 339a26934e0ff..c6d1b875ef858 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala @@ -350,8 +350,8 @@ class AFTSurvivalRegressionSuite test("SPARK-15892: Incorrectly merged AFTAggregator with zero total count") { // This `dataset` will contain an empty partition because it has five rows but // the parallelism is bigger than that. Because the issue was about `AFTAggregator`s - // being merged incorrectly when it has an empty partition, running the codes below - // should not throw an exception. + // being merged incorrectly when it has an empty partition, the trained model always has + // 1.0 scale from Euler's number for 0. 
val points = sc.parallelize(Seq( AFTPoint(Vectors.dense(1.560, -0.605), 1.218, 1.0), AFTPoint(Vectors.dense(0.346, 2.158), 2.949, 0.0), From 9cf8152c6121faa37a9cd337fddfa22f1457fc43 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Fri, 17 Jun 2016 10:24:51 +0900 Subject: [PATCH 3/3] Fix comment --- .../apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala index c6d1b875ef858..70f9693b4e96b 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala @@ -350,7 +350,7 @@ class AFTSurvivalRegressionSuite test("SPARK-15892: Incorrectly merged AFTAggregator with zero total count") { // This `dataset` will contain an empty partition because it has five rows but // the parallelism is bigger than that. Because the issue was about `AFTAggregator`s - // being merged incorrectly when it has an empty partition, the trained model always has + // being merged incorrectly when it has an empty partition, the trained model has // 1.0 scale from Euler's number for 0. val points = sc.parallelize(Seq( AFTPoint(Vectors.dense(1.560, -0.605), 1.218, 1.0),