From 81cd88cfb44a1b0328e89cad1727461eec683038 Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Tue, 1 Mar 2016 22:25:23 +0800 Subject: [PATCH 1/6] Implement a simple wrapper of AFTSurvivalRegression in SparkR --- R/pkg/NAMESPACE | 3 +- R/pkg/R/generics.R | 4 ++ R/pkg/R/mllib.R | 33 +++++++++++ R/pkg/inst/tests/testthat/test_mllib.R | 14 +++++ .../apache/spark/ml/r/SparkRWrappers.scala | 56 ++++++++++++++++++- .../ml/regression/AFTSurvivalRegression.scala | 45 ++++++++++++++- 6 files changed, 151 insertions(+), 4 deletions(-) diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE index 636d39e1e9ca..dae7abafd757 100644 --- a/R/pkg/NAMESPACE +++ b/R/pkg/NAMESPACE @@ -15,7 +15,8 @@ exportMethods("glm", "predict", "summary", "kmeans", - "fitted") + "fitted", + "survreg") # Job group lifecycle management methods export("setJobGroup", diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R index 6ad71fcb4671..25062b82705a 100644 --- a/R/pkg/R/generics.R +++ b/R/pkg/R/generics.R @@ -1175,3 +1175,7 @@ setGeneric("kmeans") #' @rdname fitted #' @export setGeneric("fitted") + +#' @rdname survreg +#' @export +setGeneric("survreg", function(formula, data) { standardGeneric("survreg") }) diff --git a/R/pkg/R/mllib.R b/R/pkg/R/mllib.R index 5c0d3dcf3af9..5d6e55131a7f 100644 --- a/R/pkg/R/mllib.R +++ b/R/pkg/R/mllib.R @@ -61,6 +61,34 @@ setMethod("glm", signature(formula = "formula", family = "ANY", data = "DataFram return(new("PipelineModel", model = model)) }) +#' Fit an accelerated failure time (AFT) survival regression model. +#' +#' Fit an accelerated failure time (AFT) survival regression model, similarly to R's survreg(). +#' +#' @param formula A symbolic description of the model to be fitted. Currently only a few formula +#' operators are supported, including '~', ':', '+', and '-'. +#' @param data DataFrame for training. +#' @return a fitted MLlib model +#' @rdname survreg +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlContext <- sparkRSQL.init(sc) +#' library(survival) +#' data(ovarian) +#' df <- createDataFrame(sqlContext, ovarian) +#' model <- survreg(Surv(futime, fustat) ~ ecog_ps + rx, df) +#' summary(model) +#'} +setMethod("survreg", signature(formula = "formula", data = "DataFrame"), + function(formula, data) { + formula <- paste(deparse(formula), collapse="") + model <- callJStatic("org.apache.spark.ml.api.r.SparkRWrappers", + "fitAFTSurvivalRegression", formula, data@sdf) + return(new("PipelineModel", model = model)) + }) + #' Make predictions from a model #' #' Makes predictions from a model produced by glm(), similarly to R's predict(). 
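For context on what this wrapper fits: MLlib's AFTSurvivalRegression estimates a Weibull accelerated
failure time model,

    \log T = \beta_0 + x^\top \beta + \sigma \, \varepsilon

where \varepsilon follows a standard extreme-value distribution and \sigma > 0 is the scale
parameter. The optimizer works on \log \sigma to keep the scale positive (train() recovers it as
scale = math.exp(parameters(0))), hence the coefficient vector surfaced to R is
((Intercept), \beta_1, ..., \beta_p, Log(scale)), the same layout that survival::survreg prints.
The summary() hunk below builds exactly that matrix.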
@@ -135,6 +163,11 @@ setMethod("summary", signature(object = "PipelineModel"), colnames(coefficients) <- unlist(features) rownames(coefficients) <- 1:k return(list(coefficients = coefficients, size = size, cluster = dataFrame(cluster))) + } else if (modelName == "AFTSurvivalRegressionModel") { + coefficients <- as.matrix(unlist(coefficients)) + colnames(coefficients) <- c("Value") + rownames(coefficients) <- unlist(features) + return(list(coefficients = coefficients)) } else { stop(paste("Unsupported model", modelName, sep = " ")) } diff --git a/R/pkg/inst/tests/testthat/test_mllib.R b/R/pkg/inst/tests/testthat/test_mllib.R index e120462964d1..0a625f059ae0 100644 --- a/R/pkg/inst/tests/testthat/test_mllib.R +++ b/R/pkg/inst/tests/testthat/test_mllib.R @@ -141,3 +141,17 @@ test_that("kmeans", { cluster <- summary.model$cluster expect_equal(sort(collect(distinct(select(cluster, "prediction")))$prediction), c(0, 1)) }) + +test_that("survreg vs survival::survreg", { + data <- list(list(4,1,0,0), list(3,1,2,0), list(1,1,1,0), + list(1,0,1,0), list(2,1,1,1), list(2,1,0,1), list(3,0,0,1)) + df <- createDataFrame(sqlContext, data, c("time", "status", "x", "sex")) + model <- survreg(Surv(time, status) ~ x + sex, df) + stats <- summary(model) + coefs <- as.vector(stats$coefficients[,1]) + rCoefs <- c(1.3149571, -0.1903409, -0.2532618, -1.1599802) + expect_true(all(abs(rCoefs - coefs) < 1e-4)) + expect_true(all( + rownames(stats$coefficients) == + c("(Intercept)", "x", "sex", "Log(scale)"))) +}) diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/SparkRWrappers.scala b/mllib/src/main/scala/org/apache/spark/ml/r/SparkRWrappers.scala index d23e4fc9d1f5..7dbb5ed3dd12 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/r/SparkRWrappers.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/r/SparkRWrappers.scala @@ -17,12 +17,13 @@ package org.apache.spark.ml.api.r +import org.apache.spark.SparkException import org.apache.spark.ml.{Pipeline, PipelineModel} import org.apache.spark.ml.attribute._ import org.apache.spark.ml.classification.{LogisticRegression, LogisticRegressionModel} import org.apache.spark.ml.clustering.{KMeans, KMeansModel} import org.apache.spark.ml.feature.{RFormula, VectorAssembler} -import org.apache.spark.ml.regression.{LinearRegression, LinearRegressionModel} +import org.apache.spark.ml.regression._ import org.apache.spark.sql.DataFrame private[r] object SparkRWrappers { @@ -52,6 +53,43 @@ private[r] object SparkRWrappers { pipeline.fit(df) } + def fitAFTSurvivalRegression( + value: String, + df: DataFrame): PipelineModel = { + + def formulaRewrite(value: String): (String, String) = { + var rewrited: String = null + var censorCol: String = null + + val regex = "^Surv\\(([^,]*),([^,]*)\\)\\s*\\~\\s*(.*)".r + try { + val regex(label, censor, features) = value + // TODO: Support dot operator. 
+        if (features.contains(".")) {
+          throw new UnsupportedOperationException(
+            "The dot operator is not yet supported in survreg formulas.")
+        }
+        rewrited = label.trim + "~" + features
+        censorCol = censor.trim
+      } catch {
+        case e: MatchError =>
+          throw new SparkException(s"Could not parse formula: $value")
+      }
+
+      (rewrited, censorCol)
+    }
+
+    val (rewritedValue, censorCol) = formulaRewrite(value)
+
+    val formula = new RFormula().setFormula(rewritedValue)
+    val estimator = new AFTSurvivalRegression()
+      .setCensorCol(censorCol)
+      .setFitIntercept(formula.hasIntercept)
+
+    val pipeline = new Pipeline().setStages(Array(formula, estimator))
+    pipeline.fit(df)
+  }
+
   def fitKMeans(
       df: DataFrame,
       initMode: String,
@@ -91,6 +129,12 @@ private[r] object SparkRWrappers {
         }
       case m: KMeansModel =>
         m.clusterCenters.flatMap(_.toArray)
+      case m: AFTSurvivalRegressionModel =>
+        if (m.getFitIntercept) {
+          Array(m.intercept) ++ m.coefficients.toArray ++ Array(math.log(m.scale))
+        } else {
+          m.coefficients.toArray ++ Array(math.log(m.scale))
+        }
     }
   }

@@ -151,6 +195,14 @@ private[r] object SparkRWrappers {
         val attrs = AttributeGroup.fromStructField(
           m.summary.predictions.schema(m.summary.featuresCol))
         attrs.attributes.get.map(_.name.get)
+      case m: AFTSurvivalRegressionModel =>
+        val attrs = AttributeGroup.fromStructField(
+          m.summary.predictions.schema(m.summary.featuresCol))
+        if (m.getFitIntercept) {
+          Array("(Intercept)") ++ attrs.attributes.get.map(_.name.get) ++ Array("Log(scale)")
+        } else {
+          attrs.attributes.get.map(_.name.get) ++ Array("Log(scale)")
+        }
     }
   }

@@ -162,6 +214,8 @@ private[r] object SparkRWrappers {
         "LogisticRegressionModel"
       case m: KMeansModel =>
         "KMeansModel"
+      case m: AFTSurvivalRegressionModel =>
+        "AFTSurvivalRegressionModel"
     }
   }
 }
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala
index ba5708ab8d9b..fc67a90a8af3 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala
@@ -232,8 +232,12 @@ class AFTSurvivalRegression @Since("1.6.0") (@Since("1.6.0") override val uid: S
     val coefficients = Vectors.dense(parameters.slice(2, parameters.length))
     val intercept = parameters(1)
     val scale = math.exp(parameters(0))
-    val model = new AFTSurvivalRegressionModel(uid, coefficients, intercept, scale)
-    copyValues(model.setParent(this))
+    val model = copyValues(
+      new AFTSurvivalRegressionModel(uid, coefficients, intercept, scale)
+        .setParent(this))
+    val summary = new AFTSurvivalRegressionSummary(model.transform(dataset),
+      $(predictionCol), $(labelCol), $(featuresCol))
+    model.setSummary(summary)
   }

   @Since("1.6.0")
@@ -281,6 +285,26 @@ class AFTSurvivalRegressionModel private[ml] (
   @Since("1.6.0")
   def setQuantilesCol(value: String): this.type = set(quantilesCol, value)

+  private var trainingSummary: Option[AFTSurvivalRegressionSummary] = None
+
+  private[regression] def setSummary(summary: AFTSurvivalRegressionSummary): this.type = {
+    this.trainingSummary = Some(summary)
+    this
+  }
+
+  /**
+   * Gets summary of model on training set. An exception is
+   * thrown if `trainingSummary == None`.
+ */ + @Since("2.0.0") + def summary: AFTSurvivalRegressionSummary = trainingSummary match { + case Some(summ) => summ + case None => + throw new SparkException( + "No training summary available for this AFTSurvivalRegressionModel", + new NullPointerException()) + } + @Since("1.6.0") def predictQuantiles(features: Vector): Vector = { // scale parameter for the Weibull distribution of lifetime @@ -375,6 +399,23 @@ object AFTSurvivalRegressionModel extends MLReadable[AFTSurvivalRegressionModel] } } +/** + * :: Experimental :: + * AFT survival regression results evaluated on a dataset. + * + * @param predictions dataframe outputted by the model's `transform` method. + * @param predictionCol field in "predictions" which gives the prediction of each instance. + * @param labelCol field in "predictions" which gives the true label of each instance. + * @param featuresCol field in "predictions" which gives the features of each instance as a vector. + */ +@Experimental +@Since("2.0.0") +class AFTSurvivalRegressionSummary private[regression] ( + @Since("2.0.0") @transient val predictions: DataFrame, + @Since("2.0.0") val predictionCol: String, + @Since("2.0.0") val labelCol: String, + @Since("2.0.0") val featuresCol: String) extends Serializable + /** * AFTAggregator computes the gradient and loss for a AFT loss function, * as used in AFT survival regression for samples in sparse or dense vector in a online fashion. From 96fe3c2c0dd674967d71665b2e23c1540c803ae0 Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Wed, 2 Mar 2016 18:26:59 +0800 Subject: [PATCH 2/6] Fix regex --- mllib/src/main/scala/org/apache/spark/ml/r/SparkRWrappers.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/SparkRWrappers.scala b/mllib/src/main/scala/org/apache/spark/ml/r/SparkRWrappers.scala index 7dbb5ed3dd12..9038b14ff1e6 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/r/SparkRWrappers.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/r/SparkRWrappers.scala @@ -61,7 +61,7 @@ private[r] object SparkRWrappers { var rewrited: String = null var censorCol: String = null - val regex = "^Surv\\(([^,]*),([^,]*)\\)\\s*\\~\\s*(.*)".r + val regex = "^Surv\\(([^,]+),([^,]+)\\)\\s*\\~\\s*(.+)".r try { val regex(label, censor, features) = value // TODO: Support dot operator. From 4ccb1721dff8fc8005bf2d8b822c5e9bbe06e9b3 Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Thu, 17 Mar 2016 21:29:48 +0800 Subject: [PATCH 3/6] Fix regex & handle invalid prediction column --- .../apache/spark/ml/r/SparkRWrappers.scala | 4 +- .../ml/regression/AFTSurvivalRegression.scala | 37 ++++++++++++------- 2 files changed, 26 insertions(+), 15 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/SparkRWrappers.scala b/mllib/src/main/scala/org/apache/spark/ml/r/SparkRWrappers.scala index 9038b14ff1e6..ec48ed96bcb3 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/r/SparkRWrappers.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/r/SparkRWrappers.scala @@ -61,7 +61,7 @@ private[r] object SparkRWrappers { var rewrited: String = null var censorCol: String = null - val regex = "^Surv\\(([^,]+),([^,]+)\\)\\s*\\~\\s*(.+)".r + val regex = "^Surv\\s*\\(([^,]+),([^,]+)\\)\\s*\\~\\s*(.+)".r try { val regex(label, censor, features) = value // TODO: Support dot operator. 
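The rewrite driven by this regex is easy to exercise in isolation. A minimal sketch (the object name
and the sample formula string are illustrative, not part of the patch):

    object FormulaRewriteDemo {
      def main(args: Array[String]): Unit = {
        // Pattern from this revision: tolerate whitespace after "Surv" and
        // require non-empty label, censor, and feature groups.
        val regex = "^Surv\\s*\\(([^,]+),([^,]+)\\)\\s*\\~\\s*(.+)".r
        "Surv(futime, fustat) ~ ecog_ps + rx" match {
          case regex(label, censor, features) =>
            // The rewritten formula handed to RFormula, plus the censor column:
            println(label.trim + "~" + features)  // futime~ecog_ps + rx
            println(censor.trim)                  // fustat
          case _ =>
            // fitAFTSurvivalRegression reaches this case via MatchError and
            // rethrows it as SparkException("Could not parse formula: ...").
            println("no match")
        }
      }
    }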
@@ -197,7 +197,7 @@ private[r] object SparkRWrappers {
         attrs.attributes.get.map(_.name.get)
       case m: AFTSurvivalRegressionModel =>
         val attrs = AttributeGroup.fromStructField(
-          m.summary.predictions.schema(m.summary.featuresCol))
+          m.summary.predictions.schema(m.getFeaturesCol))
         if (m.getFitIntercept) {
           Array("(Intercept)") ++ attrs.attributes.get.map(_.name.get) ++ Array("Log(scale)")
         } else {
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala
index fc67a90a8af3..fd4e1b692ddd 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala
@@ -235,8 +235,10 @@ class AFTSurvivalRegression @Since("1.6.0") (@Since("1.6.0") override val uid: S
     val model = copyValues(
       new AFTSurvivalRegressionModel(uid, coefficients, intercept, scale)
         .setParent(this))
-    val summary = new AFTSurvivalRegressionSummary(model.transform(dataset),
-      $(predictionCol), $(labelCol), $(featuresCol))
+    // Handle possible missing or invalid prediction columns
+    val (summaryModel, predictionColName) = model.findSummaryModelAndPredictionCol()
+    val summary = new AFTSurvivalRegressionSummary(
+      summaryModel.transform(dataset), predictionColName)
     model.setSummary(summary)
   }

@@ -292,17 +294,30 @@ class AFTSurvivalRegressionModel private[ml] (
     this
   }

+  /**
+   * If the prediction column is set, returns the current model and prediction column;
+   * otherwise, generates a new column and sets it as the prediction column on a new copy
+   * of the current model.
+   */
+  private[regression] def findSummaryModelAndPredictionCol()
+    : (AFTSurvivalRegressionModel, String) = {
+    $(predictionCol) match {
+      case "" =>
+        val predictionColName = "prediction_" + java.util.UUID.randomUUID.toString()
+        (copy(ParamMap.empty).setPredictionCol(predictionColName), predictionColName)
+      case p => (this, p)
+    }
+  }
+
   /**
    * Gets summary of model on training set. An exception is
    * thrown if `trainingSummary == None`.
    */
   @Since("2.0.0")
-  def summary: AFTSurvivalRegressionSummary = trainingSummary match {
-    case Some(summ) => summ
-    case None =>
-      throw new SparkException(
-        "No training summary available for this AFTSurvivalRegressionModel",
-        new NullPointerException())
+  def summary: AFTSurvivalRegressionSummary = trainingSummary.getOrElse {
+    throw new SparkException(
+      "No training summary available for this AFTSurvivalRegressionModel",
+      new RuntimeException())
   }

   @Since("1.6.0")
@@ -405,16 +420,12 @@ object AFTSurvivalRegressionModel extends MLReadable[AFTSurvivalRegressionModel]
  *
  * @param predictions dataframe outputted by the model's `transform` method.
  * @param predictionCol field in "predictions" which gives the prediction of each instance.
- * @param labelCol field in "predictions" which gives the true label of each instance.
- * @param featuresCol field in "predictions" which gives the features of each instance as a vector.
*/ @Experimental @Since("2.0.0") class AFTSurvivalRegressionSummary private[regression] ( @Since("2.0.0") @transient val predictions: DataFrame, - @Since("2.0.0") val predictionCol: String, - @Since("2.0.0") val labelCol: String, - @Since("2.0.0") val featuresCol: String) extends Serializable + @Since("2.0.0") val predictionCol: String) extends Serializable /** * AFTAggregator computes the gradient and loss for a AFT loss function, From e7605274fe029645eeac8443de8363b039462733 Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Thu, 17 Mar 2016 22:16:15 +0800 Subject: [PATCH 4/6] fix typos --- R/pkg/inst/tests/testthat/test_mllib.R | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/R/pkg/inst/tests/testthat/test_mllib.R b/R/pkg/inst/tests/testthat/test_mllib.R index 0a625f059ae0..7c1cbafc7fff 100644 --- a/R/pkg/inst/tests/testthat/test_mllib.R +++ b/R/pkg/inst/tests/testthat/test_mllib.R @@ -143,12 +143,12 @@ test_that("kmeans", { }) test_that("survreg vs survival::survreg", { - data <- list(list(4,1,0,0), list(3,1,2,0), list(1,1,1,0), - list(1,0,1,0), list(2,1,1,1), list(2,1,0,1), list(3,0,0,1)) + data <- list(list(4, 1, 0, 0), list(3, 1, 2, 0), list(1, 1, 1, 0), + list(1, 0, 1, 0), list(2, 1, 1, 1), list(2, 1, 0, 1), list(3, 0, 0, 1)) df <- createDataFrame(sqlContext, data, c("time", "status", "x", "sex")) model <- survreg(Surv(time, status) ~ x + sex, df) stats <- summary(model) - coefs <- as.vector(stats$coefficients[,1]) + coefs <- as.vector(stats$coefficients[, 1]) rCoefs <- c(1.3149571, -0.1903409, -0.2532618, -1.1599802) expect_true(all(abs(rCoefs - coefs) < 1e-4)) expect_true(all( From dbc1077c09713c7c3141dd40251ecaf9a0e74274 Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Thu, 17 Mar 2016 22:23:49 +0800 Subject: [PATCH 5/6] fix typos --- R/pkg/R/mllib.R | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/R/pkg/R/mllib.R b/R/pkg/R/mllib.R index 5d6e55131a7f..b889fd76ee6f 100644 --- a/R/pkg/R/mllib.R +++ b/R/pkg/R/mllib.R @@ -83,7 +83,7 @@ setMethod("glm", signature(formula = "formula", family = "ANY", data = "DataFram #'} setMethod("survreg", signature(formula = "formula", data = "DataFrame"), function(formula, data) { - formula <- paste(deparse(formula), collapse="") + formula <- paste(deparse(formula), collapse = "") model <- callJStatic("org.apache.spark.ml.api.r.SparkRWrappers", "fitAFTSurvivalRegression", formula, data@sdf) return(new("PipelineModel", model = model)) From 900c85fb6a45d6bccca421fdfd436ced7988581d Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Fri, 18 Mar 2016 18:16:07 +0800 Subject: [PATCH 6/6] update test case for survreg --- R/pkg/DESCRIPTION | 3 ++- R/pkg/inst/tests/testthat/test_mllib.R | 23 +++++++++++++++-------- 2 files changed, 17 insertions(+), 9 deletions(-) diff --git a/R/pkg/DESCRIPTION b/R/pkg/DESCRIPTION index 0cd0d75df0f7..6d5b2b931cb0 100644 --- a/R/pkg/DESCRIPTION +++ b/R/pkg/DESCRIPTION @@ -11,7 +11,8 @@ Depends: R (>= 3.0), methods, Suggests: - testthat + testthat, + survival Description: R frontend for Spark License: Apache License (== 2.0) Collate: diff --git a/R/pkg/inst/tests/testthat/test_mllib.R b/R/pkg/inst/tests/testthat/test_mllib.R index 7c1cbafc7fff..b1c45b3e3ab6 100644 --- a/R/pkg/inst/tests/testthat/test_mllib.R +++ b/R/pkg/inst/tests/testthat/test_mllib.R @@ -142,16 +142,23 @@ test_that("kmeans", { expect_equal(sort(collect(distinct(select(cluster, "prediction")))$prediction), c(0, 1)) }) -test_that("survreg vs survival::survreg", { - data <- list(list(4, 1, 0, 0), list(3, 1, 2, 0), list(1, 1, 
1, 0), - list(1, 0, 1, 0), list(2, 1, 1, 1), list(2, 1, 0, 1), list(3, 0, 0, 1)) - df <- createDataFrame(sqlContext, data, c("time", "status", "x", "sex")) - model <- survreg(Surv(time, status) ~ x + sex, df) +test_that("SparkR::survreg vs survival::survreg", { + library(survival) + data(ovarian) + df <- suppressWarnings(createDataFrame(sqlContext, ovarian)) + + model <- SparkR::survreg(Surv(futime, fustat) ~ ecog_ps + rx, df) stats <- summary(model) - coefs <- as.vector(stats$coefficients[, 1]) - rCoefs <- c(1.3149571, -0.1903409, -0.2532618, -1.1599802) + coefs <- as.vector(stats$coefficients[, 1][1:3]) + scale <- exp(stats$coefficients[, 1][4]) + + rModel <- survival::survreg(Surv(futime, fustat) ~ ecog.ps + rx, ovarian) + rCoefs <- as.vector(coef(rModel)) + rScale <- rModel$scale + expect_true(all(abs(rCoefs - coefs) < 1e-4)) + expect_true(abs(rScale - scale) < 1e-4) expect_true(all( rownames(stats$coefficients) == - c("(Intercept)", "x", "sex", "Log(scale)"))) + c("(Intercept)", "ecog_ps", "rx", "Log(scale)"))) })
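Taken together, the survreg() call exercised by this test reduces to the two-stage pipeline that
fitAFTSurvivalRegression builds. A minimal self-contained Scala sketch, assuming a local
SparkContext of the same era and reusing the toy data from the earlier test revision; the object
name is illustrative, and fitIntercept is hard-coded to true here, whereas the wrapper derives it
from the parsed formula:

    import org.apache.spark.{SparkConf, SparkContext}
    import org.apache.spark.ml.Pipeline
    import org.apache.spark.ml.feature.RFormula
    import org.apache.spark.ml.regression.AFTSurvivalRegression
    import org.apache.spark.sql.SQLContext

    object SurvregPipelineSketch {
      def main(args: Array[String]): Unit = {
        val sc = new SparkContext(
          new SparkConf().setMaster("local[2]").setAppName("survreg-sketch"))
        val sqlContext = new SQLContext(sc)
        import sqlContext.implicits._

        // (time, status, x, sex) rows from the earlier test revision.
        val df = sc.parallelize(Seq(
          (4.0, 1.0, 0.0, 0.0), (3.0, 1.0, 2.0, 0.0), (1.0, 1.0, 1.0, 0.0),
          (1.0, 0.0, 1.0, 0.0), (2.0, 1.0, 1.0, 1.0), (2.0, 1.0, 0.0, 1.0),
          (3.0, 0.0, 0.0, 1.0))).toDF("time", "status", "x", "sex")

        // "Surv(time, status) ~ x + sex" after the formula rewrite: the label and
        // features go to RFormula, the censor column is configured separately.
        val formula = new RFormula().setFormula("time ~ x + sex")
        val aft = new AFTSurvivalRegression()
          .setCensorCol("status")
          .setFitIntercept(true)
        val model = new Pipeline().setStages(Array(formula, aft)).fit(df)
        model.transform(df).show()
        sc.stop()
      }
    }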