Skip to content

Commit e009ee1

Browse files
committed
address comment issues
1 parent a33c4ea commit e009ee1

File tree

6 files changed

+44
-45
lines changed

6 files changed

+44
-45
lines changed

mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -83,7 +83,8 @@ private[shared] object SharedParamsCodeGen {
8383
ParamDesc[String]("solver", "the solver algorithm for optimization", finalFields = false),
8484
ParamDesc[Int]("aggregationDepth", "suggested depth for treeAggregate (>= 2)", Some("2"),
8585
isValid = "ParamValidators.gtEq(2)", isExpertParam = true),
86-
ParamDesc[Boolean]("collectSubModels", "whether to collect sub models when tuning fitting",
86+
ParamDesc[Boolean]("collectSubModels", "whether to collect a list of sub-models trained " +
87+
"during tuning",
8788
Some("false"), isExpertParam = true)
8889
)
8990

mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -409,10 +409,10 @@ private[ml] trait HasAggregationDepth extends Params {
409409
private[ml] trait HasCollectSubModels extends Params {
410410

411411
/**
412-
* Param for whether to collect sub models when tuning fitting.
412+
* Param for whether to collect a list of sub-models trained during tuning.
413413
* @group expertParam
414414
*/
415-
final val collectSubModels: BooleanParam = new BooleanParam(this, "collectSubModels", "whether to collect sub models when tuning fitting")
415+
final val collectSubModels: BooleanParam = new BooleanParam(this, "collectSubModels", "whether to collect a list of sub-models trained during tuning")
416416

417417
setDefault(collectSubModels, false)
418418

mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala

Lines changed: 14 additions & 15 deletions
Original file line number | Diff line number | Diff line change
@@ -125,9 +125,9 @@ class CrossValidator @Since("1.2.0") (@Since("1.4.0") override val uid: String)
125125

126126
val collectSubModelsParam = $(collectSubModels)
127127

128-
var subModels: Array[Array[Model[_]]] = if (collectSubModelsParam) {
129-
Array.fill($(numFolds))(Array.fill[Model[_]](epm.length)(null))
130-
} else null
128+
var subModels: Option[Array[Array[Model[_]]]] = if (collectSubModelsParam) {
129+
Some(Array.fill($(numFolds))(Array.fill[Model[_]](epm.length)(null)))
130+
} else None
131131

132132
// Compute metrics for each model over each split
133133
val splits = MLUtils.kFold(dataset.toDF.rdd, $(numFolds), $(seed))
@@ -142,7 +142,7 @@ class CrossValidator @Since("1.2.0") (@Since("1.4.0") override val uid: String)
142142
val model = est.fit(trainingDataset, paramMap).asInstanceOf[Model[_]]
143143

144144
if (collectSubModelsParam) {
145-
subModels(splitIndex)(paramIndex) = model
145+
subModels.get(splitIndex)(paramIndex) = model
146146
}
147147
model
148148
} (executionContext)
@@ -253,7 +253,7 @@ class CrossValidatorModel private[ml] (
253253
@Since("1.4.0") override val uid: String,
254254
@Since("1.2.0") val bestModel: Model[_],
255255
@Since("1.5.0") val avgMetrics: Array[Double],
256-
@Since("2.3.0") val subModels: Array[Array[Model[_]]])
256+
@Since("2.3.0") val subModels: Option[Array[Array[Model[_]]]])
257257
extends Model[CrossValidatorModel] with CrossValidatorParams with MLWritable {
258258

259259
/** A Python-friendly auxiliary constructor. */
@@ -300,19 +300,18 @@ class CrossValidatorModel private[ml] (
300300
@Since("1.6.0")
301301
object CrossValidatorModel extends MLReadable[CrossValidatorModel] {
302302

303-
private[CrossValidatorModel] def copySubModels(subModels: Array[Array[Model[_]]]) = {
304-
var copiedSubModels: Array[Array[Model[_]]] = null
305-
if (subModels != null) {
303+
private[CrossValidatorModel] def copySubModels(subModels: Option[Array[Array[Model[_]]]]) = {
304+
subModels.map { subModels =>
306305
val numFolds = subModels.length
307306
val numParamMaps = subModels(0).length
308-
copiedSubModels = Array.fill(numFolds)(Array.fill[Model[_]](numParamMaps)(null))
307+
val copiedSubModels = Array.fill(numFolds)(Array.fill[Model[_]](numParamMaps)(null))
309308
for (i <- 0 until numFolds) {
310309
for (j <- 0 until numParamMaps) {
311310
copiedSubModels(i)(j) = subModels(i)(j).copy(ParamMap.empty).asInstanceOf[Model[_]]
312311
}
313312
}
313+
copiedSubModels
314314
}
315-
copiedSubModels
316315
}
317316

318317
@Since("1.6.0")
@@ -345,13 +344,13 @@ object CrossValidatorModel extends MLReadable[CrossValidatorModel] {
345344
val bestModelPath = new Path(path, "bestModel").toString
346345
instance.bestModel.asInstanceOf[MLWritable].save(bestModelPath)
347346
if (shouldPersistSubModels) {
348-
require(instance.subModels != null, "Cannot get sub models to persist.")
347+
require(instance.subModels.isDefined, "Cannot get sub models to persist.")
349348
val subModelsPath = new Path(path, "subModels")
350349
for (splitIndex <- 0 until instance.getNumFolds) {
351350
val splitPath = new Path(subModelsPath, splitIndex.toString)
352351
for (paramIndex <- 0 until instance.getEstimatorParamMaps.length) {
353352
val modelPath = new Path(splitPath, paramIndex.toString).toString
354-
instance.subModels(splitIndex)(paramIndex).asInstanceOf[MLWritable].save(modelPath)
353+
instance.subModels.get(splitIndex)(paramIndex).asInstanceOf[MLWritable].save(modelPath)
355354
}
356355
}
357356
}
@@ -374,7 +373,7 @@ object CrossValidatorModel extends MLReadable[CrossValidatorModel] {
374373
val avgMetrics = (metadata.metadata \ "avgMetrics").extract[Seq[Double]].toArray
375374
val shouldPersistSubModels = (metadata.metadata \ "shouldPersistSubModels").extract[Boolean]
376375

377-
val subModels: Array[Array[Model[_]]] = if (shouldPersistSubModels) {
376+
val subModels: Option[Array[Array[Model[_]]]] = if (shouldPersistSubModels) {
378377
val subModelsPath = new Path(path, "subModels")
379378
val _subModels = Array.fill(numFolds)(Array.fill[Model[_]](
380379
estimatorParamMaps.length)(null))
@@ -386,8 +385,8 @@ object CrossValidatorModel extends MLReadable[CrossValidatorModel] {
386385
DefaultParamsReader.loadParamsInstance(modelPath, sc)
387386
}
388387
}
389-
_subModels
390-
} else null
388+
Some(_subModels)
389+
} else None
391390

392391
val model = new CrossValidatorModel(metadata.uid, bestModel, avgMetrics, subModels)
393392
model.set(model.estimator, estimator)

mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala

Lines changed: 14 additions & 15 deletions
Original file line number | Diff line number | Diff line change
@@ -128,9 +128,9 @@ class TrainValidationSplit @Since("1.5.0") (@Since("1.5.0") override val uid: St
128128

129129
val collectSubModelsParam = $(collectSubModels)
130130

131-
var subModels: Array[Model[_]] = if (collectSubModelsParam) {
132-
Array.fill[Model[_]](epm.length)(null)
133-
} else null
131+
var subModels: Option[Array[Model[_]]] = if (collectSubModelsParam) {
132+
Some(Array.fill[Model[_]](epm.length)(null))
133+
} else None
134134

135135
// Fit models in a Future for training in parallel
136136
logDebug(s"Train split with multiple sets of parameters.")
@@ -139,7 +139,7 @@ class TrainValidationSplit @Since("1.5.0") (@Since("1.5.0") override val uid: St
139139
val model = est.fit(trainingDataset, paramMap).asInstanceOf[Model[_]]
140140

141141
if (collectSubModelsParam) {
142-
subModels(paramIndex) = model
142+
subModels.get(paramIndex) = model
143143
}
144144
model
145145
} (executionContext)
@@ -246,7 +246,7 @@ class TrainValidationSplitModel private[ml] (
246246
@Since("1.5.0") override val uid: String,
247247
@Since("1.5.0") val bestModel: Model[_],
248248
@Since("1.5.0") val validationMetrics: Array[Double],
249-
@Since("2.3.0") val subModels: Array[Model[_]])
249+
@Since("2.3.0") val subModels: Option[Array[Model[_]]])
250250
extends Model[TrainValidationSplitModel] with TrainValidationSplitParams with MLWritable {
251251

252252
/** A Python-friendly auxiliary constructor. */
@@ -293,16 +293,15 @@ class TrainValidationSplitModel private[ml] (
293293
@Since("2.0.0")
294294
object TrainValidationSplitModel extends MLReadable[TrainValidationSplitModel] {
295295

296-
private[TrainValidationSplitModel] def copySubModels(subModels: Array[Model[_]]) = {
297-
var copiedSubModels: Array[Model[_]] = null
298-
if (subModels != null) {
296+
private[TrainValidationSplitModel] def copySubModels(subModels: Option[Array[Model[_]]]) = {
297+
subModels.map { subModels =>
299298
val numParamMaps = subModels.length
300-
copiedSubModels = Array.fill[Model[_]](numParamMaps)(null)
299+
val copiedSubModels = Array.fill[Model[_]](numParamMaps)(null)
301300
for (i <- 0 until numParamMaps) {
302301
copiedSubModels(i) = subModels(i).copy(ParamMap.empty).asInstanceOf[Model[_]]
303302
}
303+
copiedSubModels
304304
}
305-
copiedSubModels
306305
}
307306

308307
@Since("2.0.0")
@@ -335,11 +334,11 @@ object TrainValidationSplitModel extends MLReadable[TrainValidationSplitModel] {
335334
val bestModelPath = new Path(path, "bestModel").toString
336335
instance.bestModel.asInstanceOf[MLWritable].save(bestModelPath)
337336
if (shouldPersistSubModels) {
338-
require(instance.subModels != null, "Cannot get sub models to persist.")
337+
require(instance.subModels.isDefined, "Cannot get sub models to persist.")
339338
val subModelsPath = new Path(path, "subModels")
340339
for (paramIndex <- 0 until instance.getEstimatorParamMaps.length) {
341340
val modelPath = new Path(subModelsPath, paramIndex.toString).toString
342-
instance.subModels(paramIndex).asInstanceOf[MLWritable].save(modelPath)
341+
instance.subModels.get(paramIndex).asInstanceOf[MLWritable].save(modelPath)
343342
}
344343
}
345344
}
@@ -360,16 +359,16 @@ object TrainValidationSplitModel extends MLReadable[TrainValidationSplitModel] {
360359
val validationMetrics = (metadata.metadata \ "validationMetrics").extract[Seq[Double]].toArray
361360
val shouldPersistSubModels = (metadata.metadata \ "shouldPersistSubModels").extract[Boolean]
362361

363-
val subModels: Array[Model[_]] = if (shouldPersistSubModels) {
362+
val subModels: Option[Array[Model[_]]] = if (shouldPersistSubModels) {
364363
val subModelsPath = new Path(path, "subModels")
365364
val _subModels = Array.fill[Model[_]](estimatorParamMaps.length)(null)
366365
for (paramIndex <- 0 until estimatorParamMaps.length) {
367366
val modelPath = new Path(subModelsPath, paramIndex.toString).toString
368367
_subModels(paramIndex) =
369368
DefaultParamsReader.loadParamsInstance(modelPath, sc)
370369
}
371-
_subModels
372-
} else null
370+
Some(_subModels)
371+
} else None
373372

374373
val model = new TrainValidationSplitModel(metadata.uid, bestModel, validationMetrics,
375374
subModels)

mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala

Lines changed: 7 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -213,24 +213,24 @@ class CrossValidatorSuite
213213

214214
val cvModel = cv.fit(dataset)
215215

216-
assert(cvModel.subModels != null && cvModel.subModels.length == numFolds)
217-
cvModel.subModels.foreach(array => assert(array.length == lrParamMaps.length))
216+
assert(cvModel.subModels.isDefined && cvModel.subModels.get.length == numFolds)
217+
cvModel.subModels.get.foreach(array => assert(array.length == lrParamMaps.length))
218218

219219
val savingPathWithoutSubModels = new File(subPath, "cvModel2").getPath
220220
cvModel.save(savingPathWithoutSubModels)
221221
val cvModel2 = CrossValidatorModel.load(savingPathWithoutSubModels)
222-
assert(cvModel2.subModels === null)
222+
assert(cvModel2.subModels.isEmpty)
223223

224224
val savingPathWithSubModels = new File(subPath, "cvModel3").getPath
225225
cvModel.save(savingPathWithSubModels, persistSubModels = true)
226226
val cvModel3 = CrossValidatorModel.load(savingPathWithSubModels)
227-
assert(cvModel3.subModels != null && cvModel3.subModels.length == numFolds)
228-
cvModel3.subModels.foreach(array => assert(array.length == lrParamMaps.length))
227+
assert(cvModel3.subModels.isDefined && cvModel3.subModels.get.length == numFolds)
228+
cvModel3.subModels.get.foreach(array => assert(array.length == lrParamMaps.length))
229229

230230
for (i <- 0 until numFolds) {
231231
for (j <- 0 until lrParamMaps.length) {
232-
assert(cvModel.subModels(i)(j).asInstanceOf[LogisticRegressionModel].uid ===
233-
cvModel3.subModels(i)(j).asInstanceOf[LogisticRegressionModel].uid)
232+
assert(cvModel.subModels.get(i)(j).asInstanceOf[LogisticRegressionModel].uid ===
233+
cvModel3.subModels.get(i)(j).asInstanceOf[LogisticRegressionModel].uid)
234234
}
235235
}
236236
}

mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala

Lines changed: 5 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -204,21 +204,21 @@ class TrainValidationSplitSuite
204204

205205
val tvsModel = tvs.fit(dataset)
206206

207-
assert(tvsModel.subModels != null && tvsModel.subModels.length == lrParamMaps.length)
207+
assert(tvsModel.subModels.isDefined && tvsModel.subModels.get.length == lrParamMaps.length)
208208

209209
val savingPathWithoutSubModels = new File(subPath, "tvsModel2").getPath
210210
tvsModel.save(savingPathWithoutSubModels)
211211
val tvsModel2 = TrainValidationSplitModel.load(savingPathWithoutSubModels)
212-
assert(tvsModel2.subModels === null)
212+
assert(tvsModel2.subModels.isEmpty)
213213

214214
val savingPathWithSubModels = new File(subPath, "tvsModel3").getPath
215215
tvsModel.save(savingPathWithSubModels, persistSubModels = true)
216216
val tvsModel3 = TrainValidationSplitModel.load(savingPathWithSubModels)
217-
assert(tvsModel3.subModels != null && tvsModel3.subModels.length == lrParamMaps.length)
217+
assert(tvsModel3.subModels.isDefined && tvsModel3.subModels.get.length == lrParamMaps.length)
218218

219219
for (i <- 0 until lrParamMaps.length) {
220-
assert(tvsModel.subModels(i).asInstanceOf[LogisticRegressionModel].uid ===
221-
tvsModel3.subModels(i).asInstanceOf[LogisticRegressionModel].uid)
220+
assert(tvsModel.subModels.get(i).asInstanceOf[LogisticRegressionModel].uid ===
221+
tvsModel3.subModels.get(i).asInstanceOf[LogisticRegressionModel].uid)
222222
}
223223
}
224224

0 commit comments

Comments
 (0)