Skip to content

Commit aedbbaa

Browse files
committed
[SPARK-6053][MLLIB] support save/load in PySpark's ALS
A simple wrapper to save/load `MatrixFactorizationModel` in Python. jkbradley Author: Xiangrui Meng <[email protected]> Closes apache#4811 from mengxr/SPARK-5991 and squashes the following commits: f135dac [Xiangrui Meng] update save doc 57e5200 [Xiangrui Meng] address comments 06140a4 [Xiangrui Meng] Merge remote-tracking branch 'apache/master' into SPARK-5991 282ec8d [Xiangrui Meng] support save/load in PySpark's ALS
1 parent fd8d283 commit aedbbaa

File tree

4 files changed

+82
-6
lines changed

4 files changed

+82
-6
lines changed

docs/mllib-collaborative-filtering.md

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -200,10 +200,8 @@ In the following example we load rating data. Each row consists of a user, a pro
200200
We use the default ALS.train() method which assumes ratings are explicit. We evaluate the
201201
recommendation by measuring the Mean Squared Error of rating prediction.
202202

203-
Note that the Python API does not yet support model save/load but will in the future.
204-
205203
{% highlight python %}
206-
from pyspark.mllib.recommendation import ALS, Rating
204+
from pyspark.mllib.recommendation import ALS, MatrixFactorizationModel, Rating
207205

208206
# Load and parse the data
209207
data = sc.textFile("data/mllib/als/test.data")
@@ -220,6 +218,10 @@ predictions = model.predictAll(testdata).map(lambda r: ((r[0], r[1]), r[2]))
220218
ratesAndPreds = ratings.map(lambda r: ((r[0], r[1]), r[2])).join(predictions)
221219
MSE = ratesAndPreds.map(lambda r: (r[1][0] - r[1][1])**2).reduce(lambda x, y: x + y) / ratesAndPreds.count()
222220
print("Mean Squared Error = " + str(MSE))
221+
222+
# Save and load model
223+
model.save(sc, "myModelPath")
224+
sameModel = MatrixFactorizationModel.load(sc, "myModelPath")
223225
{% endhighlight %}
224226

225227
If the rating matrix is derived from another source of information (i.e., it is inferred from other

mllib/src/main/scala/org/apache/spark/mllib/util/modelSaveLoad.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ trait Saveable {
4848
*
4949
* @param sc Spark context used to save model data.
5050
* @param path Path specifying the directory in which to save this model.
51-
* This directory and any intermediate directory will be created if needed.
51+
* If the directory already exists, this method throws an exception.
5252
*/
5353
def save(sc: SparkContext, path: String): Unit
5454

python/pyspark/mllib/recommendation.py

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,8 @@
1919

2020
from pyspark import SparkContext
2121
from pyspark.rdd import RDD
22-
from pyspark.mllib.common import JavaModelWrapper, callMLlibFunc
22+
from pyspark.mllib.common import JavaModelWrapper, callMLlibFunc, inherit_doc
23+
from pyspark.mllib.util import Saveable, JavaLoader
2324

2425
__all__ = ['MatrixFactorizationModel', 'ALS', 'Rating']
2526

@@ -39,7 +40,8 @@ def __reduce__(self):
3940
return Rating, (int(self.user), int(self.product), float(self.rating))
4041

4142

42-
class MatrixFactorizationModel(JavaModelWrapper):
43+
@inherit_doc
44+
class MatrixFactorizationModel(JavaModelWrapper, Saveable, JavaLoader):
4345

4446
"""A matrix factorisation model trained by regularized alternating
4547
least-squares.
@@ -81,6 +83,17 @@ class MatrixFactorizationModel(JavaModelWrapper):
8183
>>> model = ALS.trainImplicit(ratings, 1, nonnegative=True, seed=10)
8284
>>> model.predict(2,2)
8385
0.43...
86+
87+
>>> import os, tempfile
88+
>>> path = tempfile.mkdtemp()
89+
>>> model.save(sc, path)
90+
>>> sameModel = MatrixFactorizationModel.load(sc, path)
91+
>>> sameModel.predict(2,2)
92+
0.43...
93+
>>> try:
94+
... os.removedirs(path)
95+
... except:
96+
... pass
8497
"""
8598
def predict(self, user, product):
    """
    Predict the rating of ``user`` for ``product``.

    Both arguments are coerced to ``int`` before being forwarded to
    the underlying Java model.
    """
    uid, pid = int(user), int(product)
    return self._java_model.predict(uid, pid)
@@ -98,6 +111,9 @@ def userFeatures(self):
98111
def productFeatures(self):
    """
    Return the product feature vectors, fetched from the underlying
    Java model via the ``getProductFeatures`` call.
    """
    features = self.call("getProductFeatures")
    return features
100113

114+
def save(self, sc, path):
    """
    Save this model to ``path`` by delegating to the Java model's
    ``save`` method.

    :param sc: Spark context used to save model data.
    :param path: Directory in which to save this model.
    """
    # Unwrap the JavaSparkContext to the underlying Scala SparkContext,
    # which is what the Scala `save` implementation expects.
    java_spark_context = sc._jsc.sc()
    self.call("save", java_spark_context, path)
116+
101117

102118
class ALS(object):
103119

python/pyspark/mllib/util.py

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -168,6 +168,64 @@ def loadLabeledPoints(sc, path, minPartitions=None):
168168
return callMLlibFunc("loadLabeledPoints", sc, path, minPartitions)
169169

170170

171+
class Saveable(object):
    """Mixin for models and transformers that can be persisted to storage."""

    def save(self, sc, path):
        """
        Persist this model under the given path.

        Concrete implementations write:
         * human-readable (JSON) model metadata to path/metadata/
         * Parquet formatted data to path/data/

        A model saved this way can be restored with
        :py:meth:`Loader.load`.

        :param sc: Spark context used to save model data.
        :param path: Directory in which to save this model; an
                     exception is thrown if it already exists.
        """
        raise NotImplementedError
192+
193+
194+
class Loader(object):
    """
    Mixin for classes which can load saved models from files.
    """

    @classmethod
    def load(cls, sc, path):
        """
        Load a model from the given path. The model should have been
        saved using :py:meth:`Saveable.save`.

        :param sc: Spark context used for loading model files.
        :param path: Path specifying the directory to which the model
                     was saved.
        :return: model instance
        """
        # BUG FIX: the original used `raise NotImplemented`, which raises
        # the `NotImplemented` singleton (a non-exception constant used by
        # binary-operator protocols), producing a confusing TypeError at
        # call time. The correct abstract-method sentinel is the
        # NotImplementedError exception class.
        raise NotImplementedError
211+
212+
213+
class JavaLoader(Loader):
    """
    Mixin for classes whose saved models are loaded through their
    Scala implementation.
    """

    @classmethod
    def load(cls, sc, path):
        """
        Load a model by delegating to the corresponding Scala class.

        The Scala counterpart is located by mapping this class's
        module path from ``pyspark`` to ``org.apache.spark`` and
        appending the class name.

        :param sc: Spark context used for loading model files.
        :param path: Directory to which the model was saved.
        :return: instance of ``cls`` wrapping the loaded Java model
        """
        scala_package = cls.__module__.replace("pyspark", "org.apache.spark")
        full_class_name = "%s.%s" % (scala_package, cls.__name__)
        # Walk the JVM gateway attribute-by-attribute to resolve the
        # fully qualified Scala class, then invoke its `load`.
        java_obj = sc._jvm
        for part in full_class_name.split("."):
            java_obj = getattr(java_obj, part)
        return cls(java_obj.load(sc._jsc.sc(), path))
227+
228+
171229
def _test():
172230
import doctest
173231
from pyspark.context import SparkContext

0 commit comments

Comments
 (0)