from pyspark.mllib.util import Saveable, Loader, inherit_doc
from pyspark.streaming import DStream

-__all__ = ['KMeansModel', 'KMeans', 'GaussianMixtureModel', 'GaussianMixture']
+__all__ = ['KMeansModel', 'KMeans', 'GaussianMixtureModel', 'GaussianMixture',
+           'StreamingKMeans', 'StreamingKMeansModel']


@inherit_doc
@@ -273,38 +274,45 @@ def train(cls, rdd, k, convergenceTol=1e-3, maxIterations=100, seed=None, initia
class StreamingKMeansModel(KMeansModel):
    """
    .. note:: Experimental
+
    Clustering model which can perform an online update of the centroids.

    The update formula for each centroid is given by
-        c_t+1 = [(c_t * n_t * a) + (x_t * m_t)] / [n_t + m_t]
-        n_t+1 = n_t * a + m_t
+
+    * c_t+1 = ((c_t * n_t * a) + (x_t * m_t)) / (n_t + m_t)
+    * n_t+1 = n_t * a + m_t

    where

-        c_t: Centroid at the n_th iteration.
-        n_t: Number of samples (or) weights associated with the centroid
+
+    * c_t: Centroid at the n_th iteration.
+    * n_t: Number of samples (or) weights associated with the centroid
          at the n_th iteration.
-        x_t: Centroid of the new data closest to c_t.
-        m_t: Number of samples (or) weights of the new data closest to c_t.
-        c_t+1: New centroid.
-        n_t+1: New number of weights.
-        a: Decay Factor, which gives the forgetfulness.
+    * x_t: Centroid of the new data closest to c_t.
+    * m_t: Number of samples (or) weights of the new data closest to c_t.
+    * c_t+1: New centroid.
+    * n_t+1: New number of weights.
+    * a: Decay Factor, which gives the forgetfulness.

    Note that if a is set to 1, it is the weighted mean of the previous
    and new data. If it is set to zero, the old centroids are completely
    forgotten.
-    >>> initCenters, initWeights = [[0.0, 0.0], [1.0, 1.0]], [1.0, 1.0]
+    :param clusterCenters: Initial cluster centers.
+    :param clusterWeights: List of weights assigned to each cluster.
+
+    >>> initCenters = [[0.0, 0.0], [1.0, 1.0]]
+    >>> initWeights = [1.0, 1.0]
    >>> stkm = StreamingKMeansModel(initCenters, initWeights)
    >>> data = sc.parallelize([[-0.1, -0.1], [0.1, 0.1],
    ...                        [0.9, 0.9], [1.1, 1.1]])
    >>> stkm = stkm.update(data, 1.0, u"batches")
    >>> stkm.centers
    array([[ 0.,  0.],
           [ 1.,  1.]])
-    >>> stkm.predict([-0.1, -0.1]) == stkm.predict([0.1, 0.1]) == 0
-    True
-    >>> stkm.predict([0.9, 0.9]) == stkm.predict([1.1, 1.1]) == 1
-    True
+    >>> stkm.predict([-0.1, -0.1])
+    0
+    >>> stkm.predict([0.9, 0.9])
+    1
    >>> stkm.clusterWeights
    [3.0, 3.0]
    >>> decayFactor = 0.0
@@ -319,17 +327,14 @@ class StreamingKMeansModel(KMeansModel):
    0
    >>> stkm.predict([1.5, 1.5])
    1
-
-    :param clusterCenters: Initial cluster centers.
-    :param clusterWeights: List of weights assigned to each cluster.
    """
    def __init__(self, clusterCenters, clusterWeights):
        super(StreamingKMeansModel, self).__init__(centers=clusterCenters)
        self._clusterWeights = list(clusterWeights)

    @property
    def clusterWeights(self):
-        """Convenience method to return the cluster weights."""
+        """Return the cluster weights."""
        return self._clusterWeights

    @ignore_unicode_prefix
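
To make the decay update concrete, here is a minimal pure-Python sketch of the rule documented above (not the library's implementation). It normalizes by n_t * a + m_t, which equals the docstring's n_t + m_t when a = 1 and reproduces the stated limiting behaviors; update_centroid and its argument names are hypothetical.

# Standalone sketch of the documented update rule, independent of Spark.
def update_centroid(c_t, n_t, x_t, m_t, a):
    # n_t+1 = n_t * a + m_t (also used as the normalizer here)
    n_next = n_t * a + m_t
    # c_t+1 = (c_t * n_t * a + x_t * m_t) / n_t+1, applied per coordinate
    c_next = [(c * n_t * a + x * m_t) / n_next for c, x in zip(c_t, x_t)]
    return c_next, n_next

# Mirrors cluster 0 of the doctest: two new points with mean [0.0, 0.0].
c, n = update_centroid([0.0, 0.0], 1.0, [0.0, 0.0], 2.0, a=1.0)
print(c, n)  # [0.0, 0.0] 3.0 -- matches clusterWeights == [3.0, 3.0]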
@@ -338,13 +343,12 @@ def update(self, data, decayFactor, timeUnit):

        :param data: Should be an RDD that represents the new data.
        :param decayFactor: forgetfulness of the previous centroids.
-        :param timeUnit: Can be "batches" or "points"
-
-            If points, then the decay factor is raised to the power of
-            number of new points and if batches, it is used as it is.
+        :param timeUnit: Can be "batches" or "points". If points, then
+                         the decay factor is raised to the power of the
+                         number of new points; if batches, it is used
+                         as is.
        """
        if not isinstance(data, RDD):
-            raise TypeError("data should be of a RDD, got %s." % type(data))
+            raise TypeError("Data should be an RDD, got %s." % type(data))
        data = data.map(_convert_to_vector)
        decayFactor = float(decayFactor)
        if timeUnit not in ["batches", "points"]:
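
As a small illustration of the timeUnit semantics described above (a hedged sketch, not the library's internals; effective_decay is a hypothetical helper): with "points" the decay applied to the old weights compounds once per new point, with "batches" the decay factor is used as is.

def effective_decay(decayFactor, timeUnit, num_new_points):
    # "points": decay compounds once per new point; "batches": used as is.
    if timeUnit == "points":
        return decayFactor ** num_new_points
    if timeUnit == "batches":
        return decayFactor
    raise ValueError("timeUnit should be 'batches' or 'points', got %s" % timeUnit)

print(effective_decay(0.9, "points", 10))   # 0.9 ** 10 ~= 0.3487
print(effective_decay(0.9, "batches", 10))  # 0.9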
@@ -363,13 +367,15 @@ class StreamingKMeans(object):
363367 """
364368 .. note:: Experimental
365369
366- Provides methods to set k, decayFactor, timeUnit to train and
367- predict the incoming data
370+ Provides methods to set k, decayFactor, timeUnit to configure the
371+ KMeans algorithm for fitting and predicting on incoming dstreams.
372+ More details on how the centroids are updated are provided under the
373+ docs of StreamingKMeansModel.
368374
369- :param k: int, number of clusters
375+ :param k: int, number of clusters
370376 :param decayFactor: float, forgetfulness of the previous centroids.
371- :param timeUnit: can be "batches" or "points". If points, then the
372- decayfactor is raised to the power of no. of new points.
377+ :param timeUnit: can be "batches" or "points". If points, then the
378+ decayfactor is raised to the power of no. of new points.
373379 """
374380 def __init__ (self , k = 2 , decayFactor = 1.0 , timeUnit = "batches" ):
375381 self ._k = k
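
For orientation, a sketch of how this class is typically wired to a stream. It assumes a local StreamingContext and the trainOn/predictOn methods (trainOn is referenced in setInitialCenters below; predictOn is assumed to exist alongside it); the stream names are hypothetical.

from pyspark import SparkContext
from pyspark.streaming import StreamingContext
from pyspark.mllib.clustering import StreamingKMeans

sc = SparkContext(appName="StreamingKMeansSketch")
ssc = StreamingContext(sc, 1)  # 1-second batches

model = (StreamingKMeans(k=2, decayFactor=0.5, timeUnit="batches")
         .setInitialCenters([[0.0, 0.0], [1.0, 1.0]], [1.0, 1.0]))

# A queue-backed stream stands in for a real source such as a socket.
trainingStream = ssc.queueStream([sc.parallelize([[0.1, 0.1], [0.9, 0.9]])])
model.trainOn(trainingStream)             # update centroids on each batch
model.predictOn(trainingStream).pprint()  # cluster index per point

ssc.start()
ssc.awaitTerminationOrTimeout(5)
ssc.stop(stopSparkContext=True)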
@@ -406,14 +412,17 @@ def setDecayFactor(self, decayFactor):

    def setHalfLife(self, halfLife, timeUnit):
        """
-        Set number of instances after which the centroids at
-        has 0.5 weightage
+        Set the number of batches after which the centroids of that
+        particular batch have half the weightage.
        """
        self._timeUnit = timeUnit
        self._decayFactor = exp(log(0.5) / halfLife)
        return self

    def setInitialCenters(self, centers, weights):
+        """
+        Set initial centers. Should be set before calling trainOn.
+        """
        self._model = StreamingKMeansModel(centers, weights)
        return self

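A quick numeric check of the setHalfLife relation above (plain Python): the derived decay factor raised to the power halfLife is exactly 0.5, so a batch's contribution halves after halfLife time units.

from math import exp, log

halfLife = 10.0
decayFactor = exp(log(0.5) / halfLife)  # same expression as setHalfLife
print(decayFactor ** halfLife)          # 0.5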