@@ -25,6 +25,7 @@ import org.apache.hadoop.fs.Path
 
 import org.apache.spark.SparkException
 import org.apache.spark.annotation.{Experimental, Since}
+import org.apache.spark.broadcast.Broadcast
 import org.apache.spark.internal.Logging
 import org.apache.spark.ml.{Estimator, Model}
 import org.apache.spark.ml.linalg.{BLAS, Vector, Vectors, VectorUDT}
@@ -219,7 +220,9 @@ class AFTSurvivalRegression @Since("1.6.0") (@Since("1.6.0") override val uid: S
         "columns. This behavior is different from R survival::survreg.")
     }
 
-    val costFun = new AFTCostFun(instances, $(fitIntercept), featuresStd)
+    val bcFeaturesStd = instances.context.broadcast(featuresStd)
+
+    val costFun = new AFTCostFun(instances, $(fitIntercept), bcFeaturesStd)
     val optimizer = new BreezeLBFGS[BDV[Double]]($(maxIter), 10, $(tol))
 
     /*
@@ -247,6 +250,7 @@ class AFTSurvivalRegression @Since("1.6.0") (@Since("1.6.0") override val uid: S
       state.x.toArray.clone()
     }
 
+    bcFeaturesStd.destroy(blocking = false)
     if (handlePersistence) instances.unpersist()
 
     val rawCoefficients = parameters.slice(2, parameters.length)
@@ -478,26 +482,29 @@ object AFTSurvivalRegressionModel extends MLReadable[AFTSurvivalRegressionModel]
  * $$
  * </blockquote></p>
  *
- * @param parameters including three part: The log of scale parameter, the intercept and
- *                   regression coefficients corresponding to the features.
+ * @param bcParameters The broadcasted value includes three part: The log of scale parameter,
+ *                     the intercept and regression coefficients corresponding to the features.
  * @param fitIntercept Whether to fit an intercept term.
- * @param featuresStd The standard deviation values of the features.
+ * @param bcFeaturesStd The broadcast standard deviation values of the features.
  */
 private class AFTAggregator(
-    parameters: BDV[Double],
+    bcParameters: Broadcast[BDV[Double]],
     fitIntercept: Boolean,
-    featuresStd: Array[Double]) extends Serializable {
+    bcFeaturesStd: Broadcast[Array[Double]]) extends Serializable {
 
+  private val length = bcParameters.value.length
+  // make transient so we do not serialize between aggregation stages
+  @transient private lazy val parameters = bcParameters.value
   // the regression coefficients to the covariates
-  private val coefficients = parameters.slice(2, parameters.length)
-  private val intercept = parameters(1)
+  @transient private lazy val coefficients = parameters.slice(2, length)
+  @transient private lazy val intercept = parameters(1)
   // sigma is the scale parameter of the AFT model
-  private val sigma = math.exp(parameters(0))
+  @transient private lazy val sigma = math.exp(parameters(0))
@dbtsai (Member) commented on Aug 8, 2016:
In line 506,

  private val gradientSumArray = Array.ofDim[Double](parameters.length)

the code will evaluate the lazy parameters in the driver.
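
To illustrate the point, here is a minimal sketch (hypothetical class, not part of the PR): a strict val whose initializer touches a lazy val forces that lazy val at construction time, which here happens on the driver.

    import org.apache.spark.broadcast.Broadcast

    class Agg(bcParameters: Broadcast[Array[Double]]) extends Serializable {
      @transient private lazy val parameters = bcParameters.value
      // A strict val: its initializer runs when Agg is constructed on the
      // driver, forcing the lazy `parameters` (and the broadcast read) there.
      private val gradientSumArray = Array.ofDim[Double](parameters.length)
      // The PR avoids this by caching the length eagerly instead:
      //   private val length = bcParameters.value.length
      //   private val gradientSumArray = Array.ofDim[Double](length)
    }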

@dbtsai (Member) commented:

BTW, after thinking about it a bit, some of the lazy is not needed. lazy is for avoiding computation in the driver; however, @transient private val parameters = bcParameters.value should work without lazy. Also, sigma and intercept may not need lazy. Thanks.

The PR author (Contributor) replied:
@dbtsai I addressed the parameters.length issue. But I cannot remove lazy from @transient private lazy val parameters = bcParameters.value, nor from intercept/sigma; otherwise it throws a NullPointerException. If I remove both @transient and lazy, it works well, but that does not meet our requirements. It's a little weird and I'm still working to figure out the root cause. Can you give me some suggestions? Thanks.

@dbtsai (Member) replied:
You are right. In Scala, when we use @transient private val, the initializer is evaluated only once, at construction, and the field is skipped during the serialization/deserialization cycle. As a result, after the AFTAggregator is shipped to the executors, the variable will not be evaluated again and will default to null.
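
To make that concrete, here is a minimal, self-contained sketch of these semantics using plain JVM serialization (names are hypothetical, not from the PR):

    import java.io._

    object TransientDemo {

      class Holder(source: Array[Double]) extends Serializable {
        // @transient val: skipped by serialization and never re-initialized
        // on the deserialized copy, so it reads back as null.
        @transient private val strictCopy: Array[Double] = source

        // @transient lazy val: also skipped by serialization, but its
        // initializer runs again on first access after deserialization.
        @transient private lazy val lazyCopy: Array[Double] = source

        def strictIsNull: Boolean = strictCopy == null
        def lazyIsNull: Boolean = lazyCopy == null
      }

      def roundTrip[T](obj: T): T = {
        val buffer = new ByteArrayOutputStream()
        val out = new ObjectOutputStream(buffer)
        out.writeObject(obj)
        out.close()
        val in = new ObjectInputStream(new ByteArrayInputStream(buffer.toByteArray))
        in.readObject().asInstanceOf[T]
      }

      def main(args: Array[String]): Unit = {
        val copy = roundTrip(new Holder(Array(1.0, 2.0)))
        println(copy.strictIsNull) // true: the strict @transient val is null
        println(copy.lazyIsNull)   // false: the lazy val was recomputed
      }
    }

This is why parameters, coefficients, intercept, and sigma stay @transient lazy val in the diff: the initializer re-runs on first access on each executor, reading from the broadcast value rather than a serialized copy.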

   private var totalCnt: Long = 0L
   private var lossSum = 0.0
   // Here we optimize loss function over log(sigma), intercept and coefficients
-  private val gradientSumArray = Array.ofDim[Double](parameters.length)
+  private val gradientSumArray = Array.ofDim[Double](length)

   def count: Long = totalCnt
   def loss: Double = {
@@ -524,11 +531,13 @@ private class AFTAggregator(
     val ti = data.label
     val delta = data.censor
 
+    val localFeaturesStd = bcFeaturesStd.value
+
     val margin = {
       var sum = 0.0
       xi.foreachActive { (index, value) =>
-        if (featuresStd(index) != 0.0 && value != 0.0) {
-          sum += coefficients(index) * (value / featuresStd(index))
+        if (localFeaturesStd(index) != 0.0 && value != 0.0) {
+          sum += coefficients(index) * (value / localFeaturesStd(index))
         }
       }
       sum + intercept
@@ -542,8 +551,8 @@
       gradientSumArray(0) += delta + multiplier * sigma * epsilon
       gradientSumArray(1) += { if (fitIntercept) multiplier else 0.0 }
       xi.foreachActive { (index, value) =>
-        if (featuresStd(index) != 0.0 && value != 0.0) {
-          gradientSumArray(index + 2) += multiplier * (value / featuresStd(index))
+        if (localFeaturesStd(index) != 0.0 && value != 0.0) {
+          gradientSumArray(index + 2) += multiplier * (value / localFeaturesStd(index))
         }
       }

@@ -565,8 +574,7 @@
     lossSum += other.lossSum
 
     var i = 0
-    val len = this.gradientSumArray.length
-    while (i < len) {
+    while (i < length) {
       this.gradientSumArray(i) += other.gradientSumArray(i)
       i += 1
     }
@@ -583,19 +591,22 @@ private class AFTAggregator(
 private class AFTCostFun(
     data: RDD[AFTPoint],
     fitIntercept: Boolean,
-    featuresStd: Array[Double]) extends DiffFunction[BDV[Double]] {
+    bcFeaturesStd: Broadcast[Array[Double]]) extends DiffFunction[BDV[Double]] {
 
   override def calculate(parameters: BDV[Double]): (Double, BDV[Double]) = {
 
+    val bcParameters = data.context.broadcast(parameters)
+
     val aftAggregator = data.treeAggregate(
-      new AFTAggregator(parameters, fitIntercept, featuresStd))(
+      new AFTAggregator(bcParameters, fitIntercept, bcFeaturesStd))(
       seqOp = (c, v) => (c, v) match {
         case (aggregator, instance) => aggregator.add(instance)
       },
       combOp = (c1, c2) => (c1, c2) match {
         case (aggregator1, aggregator2) => aggregator1.merge(aggregator2)
       })
 
+    bcParameters.destroy(blocking = false)
     (aftAggregator.loss, aftAggregator.gradient)
   }
 }
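
Zooming out, calculate follows a common Spark pattern: broadcast the per-iteration state, run treeAggregate, then release the broadcast once the action completes. Below is a minimal sketch of that pattern with hypothetical names; note the non-blocking destroy(blocking = false) overload used in the diff is internal to Spark, so user code calls the public destroy().

    import org.apache.spark.rdd.RDD

    object BroadcastPattern {
      // Weighted sum over an RDD, shipping `weights` via a broadcast
      // variable rather than capturing it in every task closure.
      def weightedSum(data: RDD[Array[Double]], weights: Array[Double]): Double = {
        val bcWeights = data.context.broadcast(weights)

        val total = data.treeAggregate(0.0)(
          seqOp = (acc, row) => {
            val w = bcWeights.value // read on the executor from its block manager
            var i = 0
            var s = acc
            while (i < row.length && i < w.length) {
              s += w(i) * row(i)
              i += 1
            }
            s
          },
          combOp = _ + _)

        // The action has finished, so the broadcast can be released.
        bcWeights.destroy()
        total
      }
    }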