@@ -142,6 +142,7 @@ class GaussianMixtureModel(object):
142142
143143 """A clustering model derived from the Gaussian Mixture Model method.
144144
145+ >>> from pyspark.mllib.linalg import Vectors, DenseMatrix
145146 >>> clusterdata_1 = sc.parallelize(array([-0.1,-0.05,-0.01,-0.1,
146147 ... 0.9,0.8,0.75,0.935,
147148 ... -0.83,-0.68,-0.91,-0.76 ]).reshape(6, 2))
@@ -154,24 +155,51 @@ class GaussianMixtureModel(object):
154155 True
155156 >>> labels[4]==labels[5]
156157 True
157- >>> clusterdata_2 = sc.parallelize(array([-5.1971, -2.5359, -3.8220,
158- ... -5.2211, -5.0602, 4.7118,
159- ... 6.8989, 3.4592, 4.6322,
160- ... 5.7048, 4.6567, 5.5026,
161- ... 4.5605, 5.2043, 6.2734]).reshape(5, 3))
158+ >>> data = array([-5.1971, -2.5359, -3.8220,
159+ ... -5.2211, -5.0602, 4.7118,
160+ ... 6.8989, 3.4592, 4.6322,
161+ ... 5.7048, 4.6567, 5.5026,
162+ ... 4.5605, 5.2043, 6.2734])
163+ >>> clusterdata_2 = sc.parallelize(data.reshape(5,3))
162164 >>> model = GaussianMixture.train(clusterdata_2, 2, convergenceTol=0.0001,
163165 ... maxIterations=150, seed=10)
164166 >>> labels = model.predict(clusterdata_2).collect()
165167 >>> labels[0]==labels[1]==labels[2]
166168 True
167169 >>> labels[3]==labels[4]
168170 True
171+ >>> clusterdata_3 = sc.parallelize(data.reshape(15, 1))
172+ >>> im = GaussianMixtureModel([0.5, 0.5],
173+ ... [MultivariateGaussian(Vectors.dense([-1.0]), DenseMatrix(1, 1, [1.0])),
174+ ... MultivariateGaussian(Vectors.dense([1.0]), DenseMatrix(1, 1, [1.0]))])
175+ >>> model = GaussianMixture.train(clusterdata_3, 2, initialModel=im)
169176 """
170177
171178 def __init__ (self , weights , gaussians ):
172- self .weights = weights
173- self .gaussians = gaussians
174- self .k = len (self .weights )
179+ self ._weights = weights
180+ self ._gaussians = gaussians
181+ self ._k = len (self ._weights )
182+
183+ @property
184+ def weights (self ):
185+ """
186+ Weights for each Gaussian distribution in the mixture, where weights[i] is
187+ the weight for Gaussian i, and weights.sum == 1.
188+ """
189+ return self ._weights
190+
191+ @property
192+ def gaussians (self ):
193+ """
194+ Array of MultivariateGaussian where gaussians[i] represents
195+ the Multivariate Gaussian (Normal) Distribution for Gaussian i.
196+ """
197+ return self ._gaussians
198+
199+ @property
200+ def k (self ):
201+ """Number of gaussians in mixture."""
202+ return self ._k
175203
176204 def predict (self , x ):
177205 """
@@ -193,9 +221,9 @@ def predictSoft(self, x):
193221 :return: membership_matrix. RDD of array of double values.
194222 """
195223 if isinstance (x , RDD ):
196- means , sigmas = zip (* [(g .mu , g .sigma ) for g in self .gaussians ])
224+ means , sigmas = zip (* [(g .mu , g .sigma ) for g in self ._gaussians ])
197225 membership_matrix = callMLlibFunc ("predictSoftGMM" , x .map (_convert_to_vector ),
198- _convert_to_vector (self .weights ), means , sigmas )
226+ _convert_to_vector (self ._weights ), means , sigmas )
199227 return membership_matrix .map (lambda x : pyarray .array ('d' , x ))
200228
201229
@@ -208,13 +236,24 @@ class GaussianMixture(object):
208236 :param convergenceTol: Threshold value to check the convergence criteria. Defaults to 1e-3
209237 :param maxIterations: Maximum number of iterations. Defaults to 100
210238 :param seed: Random Seed
239+ :param initialModel: GaussianMixtureModel for initializing learning
211240 """
212241 @classmethod
213- def train (cls , rdd , k , convergenceTol = 1e-3 , maxIterations = 100 , seed = None ):
242+ def train (cls , rdd , k , convergenceTol = 1e-3 , maxIterations = 100 , seed = None , initialModel = None ):
214243 """Train a Gaussian Mixture clustering model."""
215- weight , mu , sigma = callMLlibFunc ("trainGaussianMixture" ,
216- rdd .map (_convert_to_vector ), k ,
217- convergenceTol , maxIterations , seed )
244+ initialModelWeights = None
245+ initialModelMu = None
246+ initialModelSigma = None
247+ if initialModel is not None :
248+ if initialModel .k != k :
249+ raise Exception ("Mismatched cluster count, initialModel.k = %s, however k = %s"
250+ % (initialModel .k , k ))
251+ initialModelWeights = initialModel .weights
252+ initialModelMu = [initialModel .gaussians [i ].mu for i in range (initialModel .k )]
253+ initialModelSigma = [initialModel .gaussians [i ].sigma for i in range (initialModel .k )]
254+ weight , mu , sigma = callMLlibFunc ("trainGaussianMixture" , rdd .map (_convert_to_vector ), k ,
255+ convergenceTol , maxIterations , seed , initialModelWeights ,
256+ initialModelMu , initialModelSigma )
218257 mvg_obj = [MultivariateGaussian (mu [i ], sigma [i ]) for i in range (k )]
219258 return GaussianMixtureModel (weight , mvg_obj )
220259
0 commit comments