Skip to content

Commit eaa8981

Browse files
committed
nit
1 parent 8a49f1c commit eaa8981

File tree

4 files changed

+34
-38
lines changed

mllib/src/main/scala/org/apache/spark/ml/classification/LinearSVC.scala

Lines changed: 8 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -187,17 +187,15 @@ class LinearSVC @Since("2.2.0") (
187187
val instances = extractInstances(dataset)
188188
.setName("training instances")
189189

190-
val (summarizer, labelSummarizer) = if ($(blockSize) == 1) {
191-
if (dataset.storageLevel == StorageLevel.NONE) {
192-
instances.persist(StorageLevel.MEMORY_AND_DISK)
193-
}
194-
Summarizer.getClassificationSummarizers(instances, $(aggregationDepth))
195-
} else {
196-
// instances will be standardized and converted to blocks, so no need to cache instances.
197-
Summarizer.getClassificationSummarizers(instances, $(aggregationDepth),
198-
Seq("mean", "std", "count", "numNonZeros"))
190+
if (dataset.storageLevel == StorageLevel.NONE && $(blockSize) == 1) {
191+
instances.persist(StorageLevel.MEMORY_AND_DISK)
199192
}
200193

194+
var requestedMetrics = Seq("mean", "std", "count")
195+
if ($(blockSize) != 1) requestedMetrics +:= "numNonZeros"
196+
val (summarizer, labelSummarizer) = Summarizer
197+
.getClassificationSummarizers(instances, $(aggregationDepth), requestedMetrics)
198+
201199
val histogram = labelSummarizer.histogram
202200
val numInvalid = labelSummarizer.countInvalid
203201
val numFeatures = summarizer.mean.size
@@ -316,7 +314,7 @@ class LinearSVC @Since("2.2.0") (
316314
}
317315
val blocks = InstanceBlock.blokify(standardized, $(blockSize))
318316
.persist(StorageLevel.MEMORY_AND_DISK)
319-
.setName(s"training dataset (blockSize=${$(blockSize)})")
317+
.setName(s"training blocks (blockSize=${$(blockSize)})")
320318

321319
val getAggregatorFunc = new BlockHingeAggregator($(fitIntercept))(_)
322320
val costFun = new RDDLossFunction(blocks, getAggregatorFunc,

mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala

Lines changed: 12 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -517,17 +517,18 @@ class LogisticRegression @Since("1.2.0") (
517517
probabilityCol, regParam, elasticNetParam, standardization, threshold, maxIter, tol,
518518
fitIntercept, blockSize)
519519

520-
val instances = extractInstances(dataset).setName("training instances")
521-
if (handlePersistence) instances.persist(StorageLevel.MEMORY_AND_DISK)
520+
val instances = extractInstances(dataset)
521+
.setName("training instances")
522522

523-
val (summarizer, labelSummarizer) = if ($(blockSize) == 1) {
524-
Summarizer.getClassificationSummarizers(instances, $(aggregationDepth))
525-
} else {
526-
// instances will be standardized and converted to blocks, so no need to cache instances.
527-
Summarizer.getClassificationSummarizers(instances, $(aggregationDepth),
528-
Seq("mean", "std", "count", "numNonZeros"))
523+
if (handlePersistence && $(blockSize) == 1) {
524+
instances.persist(StorageLevel.MEMORY_AND_DISK)
529525
}
530526

527+
var requestedMetrics = Seq("mean", "std", "count")
528+
if ($(blockSize) != 1) requestedMetrics +:= "numNonZeros"
529+
val (summarizer, labelSummarizer) = Summarizer
530+
.getClassificationSummarizers(instances, $(aggregationDepth), requestedMetrics)
531+
531532
val numFeatures = summarizer.mean.size
532533
val histogram = labelSummarizer.histogram
533534
val numInvalid = labelSummarizer.countInvalid
@@ -591,7 +592,7 @@ class LogisticRegression @Since("1.2.0") (
591592
} else {
592593
Vectors.dense(if (numClasses == 2) Double.PositiveInfinity else Double.NegativeInfinity)
593594
}
594-
if (handlePersistence) instances.unpersist()
595+
if (instances.getStorageLevel != StorageLevel.NONE) instances.unpersist()
595596
return createModel(dataset, numClasses, coefMatrix, interceptVec, Array.empty)
596597
}
597598

@@ -650,7 +651,7 @@ class LogisticRegression @Since("1.2.0") (
650651
trainOnBlocks(instances, featuresStd, numClasses, initialCoefWithInterceptMatrix,
651652
regularization, optimizer)
652653
}
653-
if (handlePersistence) instances.unpersist()
654+
if (instances.getStorageLevel != StorageLevel.NONE) instances.unpersist()
654655

655656
if (allCoefficients == null) {
656657
val msg = s"${optimizer.getClass.getName} failed."
@@ -1002,7 +1003,7 @@ class LogisticRegression @Since("1.2.0") (
10021003
}
10031004
val blocks = InstanceBlock.blokify(standardized, $(blockSize))
10041005
.persist(StorageLevel.MEMORY_AND_DISK)
1005-
.setName(s"training dataset (blockSize=${$(blockSize)})")
1006+
.setName(s"training blocks (blockSize=${$(blockSize)})")
10061007

10071008
val getAggregatorFunc = new BlockLogisticAggregator(numFeatures, numClasses, $(fitIntercept),
10081009
checkMultinomial(numClasses))(_)

mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala

Lines changed: 6 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -225,22 +225,21 @@ class AFTSurvivalRegression @Since("1.6.0") (@Since("1.6.0") override val uid: S
225225
instr.logNamedValue("quantileProbabilities.size", $(quantileProbabilities).length)
226226

227227
val instances = extractAFTPoints(dataset)
228+
.setName("training instances")
229+
228230
if ($(blockSize) == 1 && dataset.storageLevel == StorageLevel.NONE) {
229231
instances.persist(StorageLevel.MEMORY_AND_DISK)
230232
}
231233

232-
val requestedMetrics = if ($(blockSize) == 1) {
233-
Seq("mean", "std", "count")
234-
} else {
235-
Seq("mean", "std", "count", "numNonZeros")
236-
}
237-
234+
var requestedMetrics = Seq("mean", "std", "count")
235+
if ($(blockSize) != 1) requestedMetrics +:= "numNonZeros"
238236
val summarizer = instances.treeAggregate(
239237
Summarizer.createSummarizerBuffer(requestedMetrics: _*))(
240238
seqOp = (c: SummarizerBuffer, v: AFTPoint) => c.add(v.features),
241239
combOp = (c1: SummarizerBuffer, c2: SummarizerBuffer) => c1.merge(c2),
242240
depth = $(aggregationDepth)
243241
)
242+
244243
val featuresStd = summarizer.std.toArray
245244
val numFeatures = featuresStd.length
246245
instr.logNumFeatures(numFeatures)
@@ -334,7 +333,7 @@ class AFTSurvivalRegression @Since("1.6.0") (@Since("1.6.0") override val uid: S
334333
}
335334
}
336335
blocks.persist(StorageLevel.MEMORY_AND_DISK)
337-
.setName(s"training dataset (blockSize=${$(blockSize)})")
336+
.setName(s"training blocks (blockSize=${$(blockSize)})")
338337

339338
val getAggregatorFunc = new BlockAFTAggregator($(fitIntercept))(_)
340339
val costFun = new RDDLossFunction(blocks, getAggregatorFunc, None, $(aggregationDepth))

mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala

Lines changed: 8 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -355,17 +355,15 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String
355355
val instances = extractInstances(dataset)
356356
.setName("training instances")
357357

358-
val (featuresSummarizer, ySummarizer) = if ($(blockSize) == 1) {
359-
if (dataset.storageLevel == StorageLevel.NONE) {
360-
instances.persist(StorageLevel.MEMORY_AND_DISK)
361-
}
362-
Summarizer.getRegressionSummarizers(instances, $(aggregationDepth))
363-
} else {
364-
// instances will be standardized and converted to blocks, so no need to cache instances.
365-
Summarizer.getRegressionSummarizers(instances, $(aggregationDepth),
366-
Seq("mean", "std", "count", "numNonZeros"))
358+
if (dataset.storageLevel == StorageLevel.NONE && $(blockSize) == 1) {
359+
instances.persist(StorageLevel.MEMORY_AND_DISK)
367360
}
368361

362+
var requestedMetrics = Seq("mean", "std", "count")
363+
if ($(blockSize) != 1) requestedMetrics +:= "numNonZeros"
364+
val (featuresSummarizer, ySummarizer) = Summarizer
365+
.getRegressionSummarizers(instances, $(aggregationDepth), requestedMetrics)
366+
369367
val yMean = ySummarizer.mean(0)
370368
val rawYStd = ySummarizer.std(0)
371369

@@ -617,7 +615,7 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String
617615
}
618616
val blocks = InstanceBlock.blokify(standardized, $(blockSize))
619617
.persist(StorageLevel.MEMORY_AND_DISK)
620-
.setName(s"training dataset (blockSize=${$(blockSize)})")
618+
.setName(s"training blocks (blockSize=${$(blockSize)})")
621619

622620
val costFun = $(loss) match {
623621
case SquaredError =>

0 commit comments

Comments (0)