Skip to content

Commit e92fb96

Browse files
HyukjinKwonjkbradley
authored andcommitted
[SPARK-15892][ML] Incorrectly merged AFTAggregator with zero total count
## What changes were proposed in this pull request? Currently, `AFTAggregator` is not being merged correctly. For example, if there is any single empty partition in the data, this creates an `AFTAggregator` with zero total count which causes the exception below: ``` IllegalArgumentException: u'requirement failed: The number of instances should be greater than 0.0, but got 0.' ``` Please see [AFTSurvivalRegression.scala#L573-L575](https://github.com/apache/spark/blob/6ecedf39b44c9acd58cdddf1a31cf11e8e24428c/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala#L573-L575) as well. Just to be clear, the python example `aft_survival_regression.py` seems using 5 rows. So, if there exist partitions more than 5, it throws the exception above since it contains empty partitions which results in an incorrectly merged `AFTAggregator`. Executing `bin/spark-submit examples/src/main/python/ml/aft_survival_regression.py` on a machine with CPUs more than 5 is being failed because it creates tasks with some empty partitions with defualt configurations (AFAIK, it sets the parallelism level to the number of CPU cores). ## How was this patch tested? An unit test in `AFTSurvivalRegressionSuite.scala` and manually tested by `bin/spark-submit examples/src/main/python/ml/aft_survival_regression.py`. Author: hyukjinkwon <[email protected]> Author: Hyukjin Kwon <[email protected]> Closes #13619 from HyukjinKwon/SPARK-15892. (cherry picked from commit e355460) Signed-off-by: Joseph K. Bradley <[email protected]>
1 parent b699a7b commit e92fb96

File tree

2 files changed

+13
-1
lines changed

2 files changed

+13
-1
lines changed

mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -538,7 +538,7 @@ private class AFTAggregator(
538538
* @return This AFTAggregator object.
539539
*/
540540
def merge(other: AFTAggregator): this.type = {
541-
if (totalCnt != 0) {
541+
if (other.count != 0) {
542542
totalCnt += other.totalCnt
543543
lossSum += other.lossSum
544544

mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -390,6 +390,18 @@ class AFTSurvivalRegressionSuite
390390
testEstimatorAndModelReadWrite(aft, datasetMultivariate,
391391
AFTSurvivalRegressionSuite.allParamSettings, checkModelData)
392392
}
393+
394+
test("SPARK-15892: Incorrectly merged AFTAggregator with zero total count") {
395+
// This `dataset` will contain an empty partition because it has two rows but
396+
// the parallelism is bigger than that. Because the issue was about `AFTAggregator`s
397+
// being merged incorrectly when it has an empty partition, running the codes below
398+
// should not throw an exception.
399+
val dataset = spark.createDataFrame(
400+
sc.parallelize(generateAFTInput(
401+
1, Array(5.5), Array(0.8), 2, 42, 1.0, 2.0, 2.0), numSlices = 3))
402+
val trainer = new AFTSurvivalRegression()
403+
trainer.fit(dataset)
404+
}
393405
}
394406

395407
object AFTSurvivalRegressionSuite {

0 commit comments

Comments
 (0)