Skip to content

Commit d6caa02

Browse files
committed
update setSummary for other algos
1 parent 428348d commit d6caa02

File tree

8 files changed

+30
-24
lines changed

8 files changed

+30
-24
lines changed

mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -661,7 +661,7 @@ class LogisticRegression @Since("1.2.0") (
661661
$(labelCol),
662662
$(featuresCol),
663663
objectiveHistory)
664-
model.setSummary(logRegSummary)
664+
model.setSummary(Some(logRegSummary))
665665
} else {
666666
model
667667
}
@@ -803,9 +803,9 @@ class LogisticRegressionModel private[spark] (
803803
}
804804
}
805805

806-
private[classification] def setSummary(
807-
summary: LogisticRegressionTrainingSummary): this.type = {
808-
this.trainingSummary = Some(summary)
806+
private[classification]
807+
def setSummary(summary: Option[LogisticRegressionTrainingSummary]): this.type = {
808+
this.trainingSummary = summary
809809
this
810810
}
811811

@@ -900,8 +900,7 @@ class LogisticRegressionModel private[spark] (
900900
override def copy(extra: ParamMap): LogisticRegressionModel = {
901901
val newModel = copyValues(new LogisticRegressionModel(uid, coefficientMatrix, interceptVector,
902902
numClasses, isMultinomial), extra)
903-
if (trainingSummary.isDefined) newModel.setSummary(trainingSummary.get)
904-
newModel.setParent(parent)
903+
newModel.setSummary(trainingSummary).setParent(parent)
905904
}
906905

907906
override protected def raw2prediction(rawPrediction: Vector): Double = {

mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -110,8 +110,7 @@ class KMeansModel private[ml] (
110110
@Since("1.5.0")
111111
override def copy(extra: ParamMap): KMeansModel = {
112112
val copied = copyValues(new KMeansModel(uid, parentModel), extra)
113-
if (trainingSummary.isDefined) copied.setSummary(trainingSummary.get)
114-
copied.setParent(this.parent)
113+
copied.setSummary(trainingSummary).setParent(this.parent)
115114
}
116115

117116
/** @group setParam */
@@ -165,8 +164,8 @@ class KMeansModel private[ml] (
165164

166165
private var trainingSummary: Option[KMeansSummary] = None
167166

168-
private[clustering] def setSummary(summary: KMeansSummary): this.type = {
169-
this.trainingSummary = Some(summary)
167+
private[clustering] def setSummary(summary: Option[KMeansSummary]): this.type = {
168+
this.trainingSummary = summary
170169
this
171170
}
172171

@@ -325,7 +324,7 @@ class KMeans @Since("1.5.0") (
325324
val model = copyValues(new KMeansModel(uid, parentModel).setParent(this))
326325
val summary = new KMeansSummary(
327326
model.transform(dataset), $(predictionCol), $(featuresCol), $(k))
328-
model.setSummary(summary)
327+
model.setSummary(Some(summary))
329328
instr.logSuccess(model)
330329
model
331330
}

mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -270,7 +270,7 @@ class GeneralizedLinearRegression @Since("2.0.0") (@Since("2.0.0") override val
270270
.setParent(this))
271271
val trainingSummary = new GeneralizedLinearRegressionTrainingSummary(dataset, model,
272272
wlsModel.diagInvAtWA.toArray, 1, getSolver)
273-
return model.setSummary(trainingSummary)
273+
return model.setSummary(Some(trainingSummary))
274274
}
275275

276276
// Fit Generalized Linear Model by iteratively reweighted least squares (IRLS).
@@ -284,7 +284,7 @@ class GeneralizedLinearRegression @Since("2.0.0") (@Since("2.0.0") override val
284284
.setParent(this))
285285
val trainingSummary = new GeneralizedLinearRegressionTrainingSummary(dataset, model,
286286
irlsModel.diagInvAtWA.toArray, irlsModel.numIterations, getSolver)
287-
model.setSummary(trainingSummary)
287+
model.setSummary(Some(trainingSummary))
288288
}
289289

290290
@Since("2.0.0")
@@ -761,8 +761,8 @@ class GeneralizedLinearRegressionModel private[ml] (
761761
def hasSummary: Boolean = trainingSummary.nonEmpty
762762

763763
private[regression]
764-
def setSummary(summary: GeneralizedLinearRegressionTrainingSummary): this.type = {
765-
this.trainingSummary = Some(summary)
764+
def setSummary(summary: Option[GeneralizedLinearRegressionTrainingSummary]): this.type = {
765+
this.trainingSummary = summary
766766
this
767767
}
768768

@@ -778,8 +778,7 @@ class GeneralizedLinearRegressionModel private[ml] (
778778
override def copy(extra: ParamMap): GeneralizedLinearRegressionModel = {
779779
val copied = copyValues(new GeneralizedLinearRegressionModel(uid, coefficients, intercept),
780780
extra)
781-
if (trainingSummary.isDefined) copied.setSummary(trainingSummary.get)
782-
copied.setParent(parent)
781+
copied.setSummary(trainingSummary).setParent(parent)
783782
}
784783

785784
/**

mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -223,7 +223,7 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String
223223
model.diagInvAtWA.toArray,
224224
model.objectiveHistory)
225225

226-
return lrModel.setSummary(trainingSummary)
226+
return lrModel.setSummary(Some(trainingSummary))
227227
}
228228

229229
val handlePersistence = dataset.rdd.getStorageLevel == StorageLevel.NONE
@@ -276,7 +276,7 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String
276276
model,
277277
Array(0D),
278278
Array(0D))
279-
return model.setSummary(trainingSummary)
279+
return model.setSummary(Some(trainingSummary))
280280
} else {
281281
require($(regParam) == 0.0, "The standard deviation of the label is zero. " +
282282
"Model cannot be regularized.")
@@ -398,7 +398,7 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String
398398
model,
399399
Array(0D),
400400
objectiveHistory)
401-
model.setSummary(trainingSummary)
401+
model.setSummary(Some(trainingSummary))
402402
}
403403

404404
@Since("1.4.0")
@@ -444,8 +444,9 @@ class LinearRegressionModel private[ml] (
444444
throw new SparkException("No training summary available for this LinearRegressionModel")
445445
}
446446

447-
private[regression] def setSummary(summary: LinearRegressionTrainingSummary): this.type = {
448-
this.trainingSummary = Some(summary)
447+
private[regression]
448+
def setSummary(summary: Option[LinearRegressionTrainingSummary]): this.type = {
449+
this.trainingSummary = summary
449450
this
450451
}
451452

@@ -488,8 +489,7 @@ class LinearRegressionModel private[ml] (
488489
@Since("1.4.0")
489490
override def copy(extra: ParamMap): LinearRegressionModel = {
490491
val newModel = copyValues(new LinearRegressionModel(uid, coefficients, intercept), extra)
491-
if (trainingSummary.isDefined) newModel.setSummary(trainingSummary.get)
492-
newModel.setParent(parent)
492+
newModel.setSummary(trainingSummary).setParent(parent)
493493
}
494494

495495
/**

mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -147,6 +147,8 @@ class LogisticRegressionSuite
147147
assert(model.hasSummary)
148148
val copiedModel = model.copy(ParamMap.empty)
149149
assert(copiedModel.hasSummary)
150+
model.setSummary(None)
151+
assert(!model.hasSummary)
150152
}
151153

152154
test("empty probabilityCol") {

mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,9 @@ class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultR
123123
assert(clusterSizes.length === k)
124124
assert(clusterSizes.sum === numRows)
125125
assert(clusterSizes.forall(_ >= 0))
126+
127+
model.setSummary(None)
128+
assert(!model.hasSummary)
126129
}
127130

128131
test("KMeansModel transform with non-default feature and prediction cols") {

mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -197,6 +197,8 @@ class GeneralizedLinearRegressionSuite
197197
assert(model.hasSummary)
198198
val copiedModel = model.copy(ParamMap.empty)
199199
assert(copiedModel.hasSummary)
200+
model.setSummary(None)
201+
assert(!model.hasSummary)
200202

201203
assert(model.getFeaturesCol === "features")
202204
assert(model.getPredictionCol === "prediction")

mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -146,6 +146,8 @@ class LinearRegressionSuite
146146
assert(model.hasSummary)
147147
val copiedModel = model.copy(ParamMap.empty)
148148
assert(copiedModel.hasSummary)
149+
model.setSummary(None)
150+
assert(!model.hasSummary)
149151

150152
model.transform(datasetWithDenseFeature)
151153
.select("label", "prediction")

0 commit comments

Comments
 (0)