Skip to content

Commit 0832530

Browse files
zero323jkbradley
authored andcommitted
[SPARK-12006][ML][PYTHON] Fix GMM failure if initialModel is not None
If initial model passed to GMM is not empty it causes net.razorvine.pickle.PickleException. It can be fixed by converting initialModel.weights to list. Author: zero323 <[email protected]> Closes #10644 from zero323/SPARK-12006. (cherry picked from commit 592f649) Signed-off-by: Joseph K. Bradley <[email protected]>
1 parent d491464 commit 0832530

File tree

2 files changed

+13
-1
lines changed

2 files changed

+13
-1
lines changed

python/pyspark/mllib/clustering.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -255,7 +255,7 @@ def train(cls, rdd, k, convergenceTol=1e-3, maxIterations=100, seed=None, initia
255255
if initialModel.k != k:
256256
raise Exception("Mismatched cluster count, initialModel.k = %s, however k = %s"
257257
% (initialModel.k, k))
258-
initialModelWeights = initialModel.weights
258+
initialModelWeights = list(initialModel.weights)
259259
initialModelMu = [initialModel.gaussians[i].mu for i in range(initialModel.k)]
260260
initialModelSigma = [initialModel.gaussians[i].sigma for i in range(initialModel.k)]
261261
weight, mu, sigma = callMLlibFunc("trainGaussianMixture", rdd.map(_convert_to_vector), k,

python/pyspark/mllib/tests.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -310,6 +310,18 @@ def test_gmm_deterministic(self):
310310
for c1, c2 in zip(clusters1.weights, clusters2.weights):
311311
self.assertEquals(round(c1, 7), round(c2, 7))
312312

313+
def test_gmm_with_initial_model(self):
314+
from pyspark.mllib.clustering import GaussianMixture
315+
data = self.sc.parallelize([
316+
(-10, -5), (-9, -4), (10, 5), (9, 4)
317+
])
318+
319+
gmm1 = GaussianMixture.train(data, 2, convergenceTol=0.001,
320+
maxIterations=10, seed=63)
321+
gmm2 = GaussianMixture.train(data, 2, convergenceTol=0.001,
322+
maxIterations=10, seed=63, initialModel=gmm1)
323+
self.assertAlmostEqual((gmm1.weights - gmm2.weights).sum(), 0.0)
324+
313325
def test_classification(self):
314326
from pyspark.mllib.classification import LogisticRegressionWithSGD, SVMWithSGD, NaiveBayes
315327
from pyspark.mllib.tree import DecisionTree, DecisionTreeModel, RandomForest,\

0 commit comments

Comments
 (0)