Commit 2bd02e2

Browse files
viiryasrowen
authored andcommitted
[SPARK-28866][ML] Persist item factors RDD when checkpointing in ALS
### What changes were proposed in this pull request?

In the ALS ML implementation, for the non-implicit case, we checkpoint the RDD of item factors at checkpoint intervals. Before this change, the RDD was not persisted before it was checkpointed (.checkpoint()) and materialized (.count()), which caused it to be recomputed.

In an experiment, there was a performance difference between persisting and not persisting the RDD before checkpointing it. The difference is not big, but neither is this change. The actual difference varies depending on the checkpoint interval, the training dataset, etc.

### Why are the changes needed?

Persisting the RDD of item factors before checkpointing it avoids recomputation.

### Does this PR introduce any user-facing change?

No.

### How was this patch tested?

Manually checked whether the RDD is recomputed. Used 30% of the MovieLens 20M Dataset as the training dataset, set a checkpoint dir for the SparkContext, and fit an ALS model like:

```scala
val als = new ALS()
  .setMaxIter(100)
  .setCheckpointInterval(5)
  .setRegParam(0.01)
  .setUserCol("userId")
  .setItemCol("movieId")
  .setRatingCol("rating")

val t0 = System.currentTimeMillis()
val model = als.fit(training)
val t1 = System.currentTimeMillis()
```

Before this patch: 65.386 s
After this patch: 61.022 s

Closes apache#25576 from viirya/persist-item-factors.

Authored-by: Liang-Chi Hsieh <[email protected]>
Signed-off-by: Sean Owen <[email protected]>
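The recomputation described in the commit message can be observed outside ALS. The following is a minimal sketch, not the code from this patch: the object name `PersistBeforeCheckpointSketch`, the helper `countEvaluations`, the toy RDD, and the `MEMORY_AND_DISK` storage level are all assumptions made for illustration. It counts how often a `map` function runs when an RDD is checkpointed with and without a preceding `persist`.

```scala
import java.nio.file.Files

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.storage.StorageLevel

// Hypothetical illustration; none of these names come from ALS.scala.
object PersistBeforeCheckpointSketch {

  /** Returns how many times the map function ran while checkpointing a toy RDD.
   *  Assumes sc.setCheckpointDir(...) has already been called. */
  def countEvaluations(sc: SparkContext, persistFirst: Boolean): Long = {
    val evaluations = sc.longAccumulator("evaluations")
    val rdd = sc.parallelize(1 to 1000, 4).map { x =>
      evaluations.add(1) // one increment per element each time a partition is computed
      x * 2
    }
    if (persistFirst) {
      rdd.persist(StorageLevel.MEMORY_AND_DISK) // assumed storage level for this sketch
    }
    rdd.checkpoint() // only marks the RDD; the data is written after the first action
    rdd.count()      // runs the action, then Spark writes the checkpoint data
    rdd.unpersist()
    evaluations.sum
  }

  def main(args: Array[String]): Unit = {
    val conf = new SparkConf().setMaster("local[2]").setAppName("persist-before-checkpoint")
    val sc = new SparkContext(conf)
    sc.setCheckpointDir(Files.createTempDirectory("checkpoint-sketch").toString)
    try {
      println(s"map calls with persist:    ${countEvaluations(sc, persistFirst = true)}")
      println(s"map calls without persist: ${countEvaluations(sc, persistFirst = false)}")
    } finally {
      sc.stop()
    }
  }
}
```

Without `persist`, writing the checkpoint data should recompute the RDD from its lineage, so the second count is expected to be higher than the first. In ALS the recomputed step is a full `computeFactors` call rather than a cheap `map`, which is where the savings reported above would come from.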
1 parent 8279693 commit 2bd02e2

File tree

1 file changed: +6 -1 lines changed
  • mllib/src/main/scala/org/apache/spark/ml/recommendation
mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala

Lines changed: 6 additions & 1 deletion
@@ -990,16 +990,21 @@ object ALS extends DefaultParamsReadable[ALS] with Logging {
         previousUserFactors.unpersist()
       }
     } else {
+      var previousCachedItemFactors: Option[RDD[(Int, FactorBlock)]] = None
       for (iter <- 0 until maxIter) {
         itemFactors = computeFactors(userFactors, userOutBlocks, itemInBlocks, rank, regParam,
           userLocalIndexEncoder, solver = solver)
         if (shouldCheckpoint(iter)) {
+          itemFactors.setName(s"itemFactors-$iter").persist(intermediateRDDStorageLevel)
           val deps = itemFactors.dependencies
           itemFactors.checkpoint()
           itemFactors.count() // checkpoint item factors and cut lineage
           ALS.cleanShuffleDependencies(sc, deps)
           deletePreviousCheckpointFile()
+
+          previousCachedItemFactors.foreach(_.unpersist())
           previousCheckpointFile = itemFactors.getCheckpointFile
+          previousCachedItemFactors = Option(itemFactors)
         }
         userFactors = computeFactors(itemFactors, itemOutBlocks, userInBlocks, rank, regParam,
           itemLocalIndexEncoder, solver = solver)
@@ -1029,8 +1034,8 @@ object ALS extends DefaultParamsReadable[ALS] with Logging {
       .persist(finalRDDStorageLevel)
     if (finalRDDStorageLevel != StorageLevel.NONE) {
       userIdAndFactors.count()
-      itemFactors.unpersist()
       itemIdAndFactors.count()
+      itemFactors.unpersist()
       userInBlocks.unpersist()
       userOutBlocks.unpersist()
       itemInBlocks.unpersist()
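The first hunk keeps a handle to the most recently persisted item factors and releases it only once the next checkpoint has been materialized. Here is a minimal sketch of that bookkeeping on a toy RDD. The object `RollingCheckpointSketch`, the method `iterate`, the update rule, the `MEMORY_AND_DISK` storage level, and the `iter % interval == 0` condition (a stand-in for `shouldCheckpoint`) are assumptions for illustration, not code from ALS.scala.

```scala
import java.nio.file.Files

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.rdd.RDD
import org.apache.spark.storage.StorageLevel

// Hypothetical illustration of the rolling persist/unpersist bookkeeping.
object RollingCheckpointSketch {

  /** Runs maxIter toy update steps, checkpointing every `interval` iterations.
   *  Assumes sc.setCheckpointDir(...) has already been called. */
  def iterate(sc: SparkContext, maxIter: Int, interval: Int): RDD[Double] = {
    var current: RDD[Double] = sc.parallelize(Seq.fill(100)(1.0), 4)
    var previousCached: Option[RDD[Double]] = None
    for (iter <- 0 until maxIter) {
      current = current.map(_ * 0.5 + 1.0) // toy update step standing in for computeFactors
      if (iter % interval == 0) {          // stand-in for shouldCheckpoint(iter)
        // Persist first so that writing the checkpoint does not recompute the step.
        current.setName(s"state-$iter").persist(StorageLevel.MEMORY_AND_DISK)
        current.checkpoint()
        current.count() // materialize the RDD and cut its lineage
        // The previously cached RDD is no longer needed once the new one is checkpointed.
        previousCached.foreach(_.unpersist())
        previousCached = Some(current)
      }
    }
    current
  }

  def main(args: Array[String]): Unit = {
    val conf = new SparkConf().setMaster("local[2]").setAppName("rolling-checkpoint")
    val sc = new SparkContext(conf)
    sc.setCheckpointDir(Files.createTempDirectory("rolling-checkpoint").toString)
    try {
      val result = iterate(sc, maxIter = 10, interval = 2)
      println(s"elements: ${result.count()}, RDDs still cached: ${sc.getPersistentRDDs.size}")
    } finally {
      sc.stop()
    }
  }
}
```

With this bookkeeping, at most one checkpointed RDD should stay cached at a time. The last one remains cached after the loop, which would be consistent with the second hunk moving `itemFactors.unpersist()` to after `itemIdAndFactors.count()`, so that the cached factors are still available while the final item factors are materialized.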