Skip to content

Commit 13cbb2d

Browse files
yanboliang authored and mengxr committed
[SPARK-13010][ML][SPARKR] Implement a simple wrapper of AFTSurvivalRegression in SparkR
## What changes were proposed in this pull request? This PR continues the work in apache#11447 and implements a SparkR wrapper for ```AFTSurvivalRegression```, named ```survreg```. ## How was this patch tested? Tested against the output of survreg from the R package survival. cc mengxr felixcheung Close apache#11447 Author: Yanbo Liang <[email protected]> Closes apache#11932 from yanboliang/spark-13010-new.
1 parent 05f652d commit 13cbb2d

File tree

6 files changed

+231
-2
lines changed

6 files changed

+231
-2
lines changed

R/pkg/DESCRIPTION

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,8 @@ Depends:
1212
methods,
1313
Suggests:
1414
testthat,
15-
e1071
15+
e1071,
16+
survival
1617
Description: R frontend for Spark
1718
License: Apache License (== 2.0)
1819
Collate:

R/pkg/NAMESPACE

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,8 @@ exportMethods("glm",
1616
"summary",
1717
"kmeans",
1818
"fitted",
19-
"naiveBayes")
19+
"naiveBayes",
20+
"survreg")
2021

2122
# Job group lifecycle management methods
2223
export("setJobGroup",

R/pkg/R/generics.R

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1179,3 +1179,7 @@ setGeneric("fitted")
11791179
#' @rdname naiveBayes
11801180
#' @export
11811181
setGeneric("naiveBayes", function(formula, data, ...) {
  # Dispatch to the method registered for the classes of `formula` and `data`.
  standardGeneric("naiveBayes")
})
1182+
1183+
#' @rdname survreg
1184+
#' @export
1185+
setGeneric("survreg", function(formula, data, ...) {
  # Dispatch to the method registered for the classes of `formula` and `data`.
  standardGeneric("survreg")
})

R/pkg/R/mllib.R

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,11 @@ setClass("PipelineModel", representation(model = "jobj"))
2727
#' @export
2828
# S4 class with a single slot, `jobj`, declared with class "jobj".
setClass("NaiveBayesModel",
         representation = representation(jobj = "jobj"))
2929

30+
#' @title S4 class that represents an AFTSurvivalRegressionModel
31+
#' @param jobj a Java object reference to the backing Scala AFTSurvivalRegressionWrapper
32+
#' @export
33+
# S4 class with a single slot, `jobj`, declared with class "jobj".
setClass("AFTSurvivalRegressionModel",
         representation = representation(jobj = "jobj"))
34+
3035
#' Fits a generalized linear model
3136
#'
3237
#' Fits a generalized linear model, similarly to R's glm(). Also see the glmnet package.
@@ -273,3 +278,73 @@ setMethod("naiveBayes", signature(formula = "formula", data = "DataFrame"),
273278
formula, data@sdf, laplace)
274279
return(new("NaiveBayesModel", jobj = jobj))
275280
})
281+
282+
#' Fit an accelerated failure time (AFT) survival regression model.
283+
#'
284+
#' Fit an accelerated failure time (AFT) survival regression model, similarly to R's survreg().
285+
#'
286+
#' @param formula A symbolic description of the model to be fitted. Currently only a few formula
287+
#' operators are supported, including '~', ':', '+', and '-'.
288+
#' Note that operator '.' is not supported currently.
289+
#' @param data DataFrame for training.
290+
#' @return a fitted AFT survival regression model
291+
#' @rdname survreg
292+
#' @seealso survival: \url{https://cran.r-project.org/web/packages/survival/}
293+
#' @export
294+
#' @examples
295+
#' \dontrun{
296+
#' df <- createDataFrame(sqlContext, ovarian)
297+
#' model <- survreg(Surv(futime, fustat) ~ ecog_ps + rx, df)
298+
#' }
299+
setMethod("survreg", signature(formula = "formula", data = "DataFrame"),
          function(formula, data, ...) {
            # deparse() can return several strings for one formula; join them into a
            # single string before handing the formula to the JVM side.
            formulaString <- paste(deparse(formula), collapse = "")
            # Static `fit` on the Scala wrapper receives the formula text and the
            # DataFrame's backing `sdf` reference.
            wrapper <- callJStatic("org.apache.spark.ml.r.AFTSurvivalRegressionWrapper",
                                   "fit", formulaString, data@sdf)
            new("AFTSurvivalRegressionModel", jobj = wrapper)
          })
306+
307+
#' Get the summary of an AFT survival regression model
308+
#'
309+
#' Returns the summary of an AFT survival regression model produced by survreg(),
310+
#' similarly to R's summary().
311+
#'
312+
#' @param object a fitted AFT survival regression model
313+
#' @return coefficients the model's coefficients, intercept and log(scale).
314+
#' @rdname summary
315+
#' @export
316+
#' @examples
317+
#' \dontrun{
318+
#' model <- survreg(Surv(futime, fustat) ~ ecog_ps + rx, trainingData)
319+
#' summary(model)
320+
#' }
321+
setMethod("summary", signature(object = "AFTSurvivalRegressionModel"),
          function(object, ...) {
            wrapper <- object@jobj
            # Row labels and values are fetched from the JVM wrapper and flattened
            # with unlist() into plain vectors.
            termNames <- unlist(callJMethod(wrapper, "rFeatures"))
            estimates <- unlist(callJMethod(wrapper, "rCoefficients"))
            # One-column matrix: a row per term, a single "Value" column.
            coefTable <- as.matrix(estimates)
            colnames(coefTable) <- c("Value")
            rownames(coefTable) <- termNames
            list(coefficients = coefTable)
          })
331+
332+
#' Make predictions from an AFT survival regression model
333+
#'
334+
#' Make predictions from a model produced by survreg(), similarly to R package survival's predict.
335+
#'
336+
#' @param object A fitted AFT survival regression model
337+
#' @param newData DataFrame for testing
338+
#' @return DataFrame containing predicted labels in a column named "prediction"
339+
#' @rdname predict
340+
#' @export
341+
#' @examples
342+
#' \dontrun{
343+
#' model <- survreg(Surv(futime, fustat) ~ ecog_ps + rx, trainingData)
344+
#' predicted <- predict(model, testData)
345+
#' showDF(predicted)
346+
#' }
347+
setMethod("predict", signature(object = "AFTSurvivalRegressionModel"),
          function(object, newData) {
            # Run the wrapper's `transform` on the new data's backing `sdf`
            # and wrap the returned reference with dataFrame().
            transformed <- callJMethod(object@jobj, "transform", newData@sdf)
            dataFrame(transformed)
          })

R/pkg/inst/tests/testthat/test_mllib.R

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -200,3 +200,52 @@ test_that("naiveBayes", {
200200
expect_equal(as.character(predict(m, t1[1, ])), "Yes")
201201
}
202202
})
203+
204+
test_that("survreg", {
  # Reference values below were produced with the R package 'survival':
  #
  #   rData <- list(time = c(4, 3, 1, 1, 2, 2, 3), status = c(1, 1, 1, 0, 1, 1, 0),
  #                 x = c(0, 2, 1, 1, 1, 0, 0), sex = c(0, 0, 0, 0, 1, 1, 1))
  #   library(survival)
  #   model <- survreg(Surv(time, status) ~ x + sex, rData)
  #   summary(model)
  #   predict(model, rData)
  #
  # -- output of 'summary(model)'
  #
  #              Value Std. Error     z        p
  # (Intercept)  1.315      0.270  4.88 1.07e-06
  # x           -0.190      0.173 -1.10 2.72e-01
  # sex         -0.253      0.329 -0.77 4.42e-01
  # Log(scale)  -1.160      0.396 -2.93 3.41e-03
  #
  # -- output of 'predict(model, rData)'
  #
  #        1        2        3        4        5        6        7
  # 3.724591 2.545368 3.079035 3.079035 2.390146 2.891269 2.891269
  #

  # Each inner list is one row: time, status, x, sex.
  rows <- list(list(4, 1, 0, 0), list(3, 1, 2, 0), list(1, 1, 1, 0),
               list(1, 0, 1, 0), list(2, 1, 1, 1), list(2, 1, 0, 1), list(3, 0, 0, 1))
  df <- createDataFrame(sqlContext, rows, c("time", "status", "x", "sex"))
  sparkModel <- survreg(Surv(time, status) ~ x + sex, df)

  # Coefficient values and their row labels must agree with R's summary().
  modelSummary <- summary(sparkModel)
  expectedCoefs <- c(1.3149571, -0.1903409, -0.2532618, -1.1599800)
  actualCoefs <- as.vector(modelSummary$coefficients[, 1])
  expect_equal(actualCoefs, expectedCoefs, tolerance = 1e-4)
  expectedTerms <- c("(Intercept)", "x", "sex", "Log(scale)")
  expect_true(all(rownames(modelSummary$coefficients) == expectedTerms))

  # Predictions on the training rows must agree with R's predict().
  expectedPredictions <- c(3.724591, 2.545368, 3.079035, 3.079035,
                           2.390146, 2.891269, 2.891269)
  predicted <- collect(select(predict(sparkModel, df), "prediction"))
  expect_equal(predicted$prediction, expectedPredictions, tolerance = 1e-4)

  # Test survival::survreg
  if (requireNamespace("survival", quietly = TRUE)) {
    rData <- list(time = c(4, 3, 1, 1, 2, 2, 3), status = c(1, 1, 1, 0, 1, 1, 0),
                  x = c(0, 2, 1, 1, 1, 0, 0), sex = c(0, 0, 0, 0, 1, 1, 1))
    expect_that(
      rModel <- survival::survreg(formula = survival::Surv(time, status) ~ x + sex, data = rData),
      not(throws_error()))
    expect_equal(predict(rModel, rData)[[1]], 3.724591, tolerance = 1e-4)
  }
})
Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,99 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package org.apache.spark.ml.r
19+
20+
import org.apache.spark.SparkException
21+
import org.apache.spark.ml.{Pipeline, PipelineModel}
22+
import org.apache.spark.ml.attribute.AttributeGroup
23+
import org.apache.spark.ml.feature.RFormula
24+
import org.apache.spark.ml.regression.{AFTSurvivalRegression, AFTSurvivalRegressionModel}
25+
import org.apache.spark.sql.DataFrame
26+
27+
/**
 * Holds a fitted pipeline together with its feature names and exposes the values
 * the SparkR `summary` and `predict` methods ask for.
 *
 * @param pipeline fitted pipeline; its stage at index 1 is read as the
 *                 [[AFTSurvivalRegressionModel]]
 * @param features names of the model's features, used to label the coefficients
 */
private[r] class AFTSurvivalRegressionWrapper private (
    pipeline: PipelineModel,
    features: Array[String]) {

  private val aftModel: AFTSurvivalRegressionModel =
    pipeline.stages(1).asInstanceOf[AFTSurvivalRegressionModel]

  /** Coefficients followed by log(scale), with the intercept prepended when one was fitted. */
  lazy val rCoefficients: Array[Double] = {
    val coefficientsAndLogScale = aftModel.coefficients.toArray :+ math.log(aftModel.scale)
    if (aftModel.getFitIntercept) {
      aftModel.intercept +: coefficientsAndLogScale
    } else {
      coefficientsAndLogScale
    }
  }

  /** Labels matching `rCoefficients` position by position. */
  lazy val rFeatures: Array[String] = {
    val namesAndLogScale = features :+ "Log(scale)"
    if (aftModel.getFitIntercept) {
      "(Intercept)" +: namesAndLogScale
    } else {
      namesAndLogScale
    }
  }

  /** Applies the whole fitted pipeline to `dataset`. */
  def transform(dataset: DataFrame): DataFrame = pipeline.transform(dataset)
}
50+
51+
private[r] object AFTSurvivalRegressionWrapper {
52+
53+
/**
 * Splits a formula of the form `Surv(label, censor) ~ terms` into the plain
 * formula `label~terms` and the name of the censor column.
 *
 * @param formula formula text; it must match `Surv(<label>, <censor>) ~ <terms>`
 * @return the rewritten formula and the trimmed censor column name
 * @throws UnsupportedOperationException if the terms contain a '.' character
 * @throws SparkException if the formula does not have the expected shape
 */
private def formulaRewrite(formula: String): (String, String) = {
  val survFormula = """Surv\(([^,]+), ([^,]+)\) ~ (.+)""".r
  // A single match expression replaces the mutable, null-initialised results and
  // the MatchError caught as control flow.
  formula match {
    case survFormula(label, censor, features) =>
      // TODO: Support dot operator.
      if (features.contains(".")) {
        throw new UnsupportedOperationException(
          "Terms of survreg formula can not support dot operator.")
      }
      (label.trim + "~" + features.trim, censor.trim)
    case _ =>
      throw new SparkException(s"Could not parse formula: $formula")
  }
}
74+
75+
76+
/**
 * Fits an RFormula stage followed by an AFTSurvivalRegression stage on `data`
 * and wraps the fitted pipeline.
 *
 * @param formula formula text of the form `Surv(label, censor) ~ terms`
 * @param data training data
 * @return a wrapper around the fitted pipeline and its feature names
 */
def fit(formula: String, data: DataFrame): AFTSurvivalRegressionWrapper = {
  val (modelFormula, censorColumn) = formulaRewrite(formula)

  val rFormula = new RFormula().setFormula(modelFormula)
  val fittedFormula = rFormula.fit(data)

  // Feature names are read from the attributes attached to the features column
  // of the transformed schema.
  val featuresField = fittedFormula.transform(data).schema(rFormula.getFeaturesCol)
  val featureNames = AttributeGroup.fromStructField(featuresField)
    .attributes.get
    .map(_.name.get)

  val estimator = new AFTSurvivalRegression()
    .setCensorCol(censorColumn)
    .setFitIntercept(rFormula.hasIntercept)

  // The already-fitted formula model is stage 0; the AFT estimator is stage 1.
  val fittedPipeline = new Pipeline()
    .setStages(Array(fittedFormula, estimator))
    .fit(data)

  new AFTSurvivalRegressionWrapper(fittedPipeline, featureNames)
}
99+
}

0 commit comments

Comments
 (0)