|
21 | 21 | if sys.version > '3': |
22 | 22 | xrange = range |
23 | 23 |
|
24 | | -from numpy import array |
| 24 | +from math import exp, log |
| 25 | + |
| 26 | +from numpy import array, random, tile |
25 | 27 |
|
26 | | -from pyspark import RDD |
27 | 28 | from pyspark import SparkContext |
| 29 | +from pyspark.rdd import RDD, ignore_unicode_prefix |
28 | 30 | from pyspark.mllib.common import callMLlibFunc, callJavaFunc, _py2java, _java2py |
29 | | -from pyspark.mllib.linalg import SparseVector, _convert_to_vector |
| 31 | +from pyspark.mllib.linalg import SparseVector, _convert_to_vector, DenseVector |
30 | 32 | from pyspark.mllib.stat.distribution import MultivariateGaussian |
31 | 33 | from pyspark.mllib.util import Saveable, Loader, inherit_doc |
| 34 | +from pyspark.streaming import DStream |
32 | 35 |
|
33 | | -__all__ = ['KMeansModel', 'KMeans', 'GaussianMixtureModel', 'GaussianMixture'] |
| 36 | +__all__ = ['KMeansModel', 'KMeans', 'GaussianMixtureModel', 'GaussianMixture', |
| 37 | + 'StreamingKMeans', 'StreamingKMeansModel'] |
34 | 38 |
|
35 | 39 |
|
36 | 40 | @inherit_doc |
@@ -98,6 +102,9 @@ def predict(self, x): |
98 | 102 | """Find the cluster to which x belongs in this model.""" |
99 | 103 | best = 0 |
100 | 104 | best_distance = float("inf") |
| 105 | + if isinstance(x, RDD): |
| 106 | + return x.map(self.predict) |
| 107 | + |
101 | 108 | x = _convert_to_vector(x) |
102 | 109 | for i in xrange(len(self.centers)): |
103 | 110 | distance = x.squared_distance(self.centers[i]) |
@@ -264,6 +271,198 @@ def train(cls, rdd, k, convergenceTol=1e-3, maxIterations=100, seed=None, initia |
264 | 271 | return GaussianMixtureModel(weight, mvg_obj) |
265 | 272 |
|
266 | 273 |
|
class StreamingKMeansModel(KMeansModel):
    """
    .. note:: Experimental

    K-means model whose centroids can be refreshed online as new data
    arrives.

    Every centroid, together with its weight, is updated as

    * c_t+1 = ((c_t * n_t * a) + (x_t * m_t)) / (n_t * a + m_t)
    * n_t+1 = n_t * a + m_t

    where

    * c_t: Centroid at the t_th iteration.
    * n_t: Number of samples (or) weights associated with the centroid
           at the t_th iteration.
    * x_t: Centroid of the new data closest to c_t.
    * m_t: Number of samples (or) weights of the new data closest to c_t
    * c_t+1: New centroid.
    * n_t+1: New number of weights.
    * a: Decay Factor, which gives the forgetfulness.

    With a equal to 1 the result is the weighted mean of the previous
    and the new data. With a equal to 0 the previous centroids are
    forgotten completely (see the second update in the example below).

    :param clusterCenters: Initial cluster centers.
    :param clusterWeights: List of weights assigned to each cluster.

    >>> initCenters = [[0.0, 0.0], [1.0, 1.0]]
    >>> initWeights = [1.0, 1.0]
    >>> stkm = StreamingKMeansModel(initCenters, initWeights)
    >>> data = sc.parallelize([[-0.1, -0.1], [0.1, 0.1],
    ...                        [0.9, 0.9], [1.1, 1.1]])
    >>> stkm = stkm.update(data, 1.0, u"batches")
    >>> stkm.centers
    array([[ 0.,  0.],
           [ 1.,  1.]])
    >>> stkm.predict([-0.1, -0.1])
    0
    >>> stkm.predict([0.9, 0.9])
    1
    >>> stkm.clusterWeights
    [3.0, 3.0]
    >>> decayFactor = 0.0
    >>> data = sc.parallelize([DenseVector([1.5, 1.5]), DenseVector([0.2, 0.2])])
    >>> stkm = stkm.update(data, 0.0, u"batches")
    >>> stkm.centers
    array([[ 0.2,  0.2],
           [ 1.5,  1.5]])
    >>> stkm.clusterWeights
    [1.0, 1.0]
    >>> stkm.predict([0.2, 0.2])
    0
    >>> stkm.predict([1.5, 1.5])
    1
    """
    def __init__(self, clusterCenters, clusterWeights):
        super(StreamingKMeansModel, self).__init__(centers=clusterCenters)
        # Copy into a fresh list so the model owns its weights.
        self._clusterWeights = list(clusterWeights)

    @property
    def clusterWeights(self):
        """Current weight of every cluster, as a list."""
        return self._clusterWeights

    @ignore_unicode_prefix
    def update(self, data, decayFactor, timeUnit):
        """Fold a new batch of points into the centroids and weights.

        :param data: RDD holding the new points; each element is
                     converted to a vector before the update.
        :param decayFactor: forgetfulness of the previous centroids;
                            coerced with ``float``.
        :param timeUnit: Either "batches" or "points". With "points" the
                         decay factor is raised to the power of the
                         number of new points; with "batches" it is
                         used unchanged.
        :return: this model, with ``centers`` and ``clusterWeights``
                 replaced by the updated values.
        :raises TypeError: if ``data`` is not an RDD.
        :raises ValueError: if ``timeUnit`` is not one of the two
                            accepted values.
        """
        if not isinstance(data, RDD):
            raise TypeError("Data should be of an RDD, got %s." % type(data))
        decay = float(decayFactor)
        if timeUnit not in ("batches", "points"):
            raise ValueError(
                "timeUnit should be 'batches' or 'points', got %s." % timeUnit)

        # Arguments are valid: convert both the incoming points and the
        # current centers to vectors before handing them to the JVM side.
        points = data.map(_convert_to_vector)
        currentCenters = [_convert_to_vector(c) for c in self.centers]
        result = callMLlibFunc(
            "updateStreamingKMeansModel", currentCenters, self._clusterWeights,
            points, decay, timeUnit)

        # result[0] holds the new centers, result[1] the new weights.
        self.centers = array(result[0])
        self._clusterWeights = list(result[1])
        return self
| 364 | + |
| 365 | + |
class StreamingKMeans(object):
    """
    .. note:: Experimental

    Provides methods to set k, decayFactor, timeUnit to configure the
    KMeans algorithm for fitting and predicting on incoming dstreams.
    More details on how the centroids are updated are provided under the
    docs of StreamingKMeansModel.

    :param k: int, number of clusters
    :param decayFactor: float, forgetfulness of the previous centroids.
    :param timeUnit: can be "batches" or "points". If points, then the
                     decayfactor is raised to the power of no. of new points.
    """
    # The only time units accepted by the constructor and setHalfLife.
    _TIME_UNITS = ("batches", "points")

    def __init__(self, k=2, decayFactor=1.0, timeUnit="batches"):
        self._k = k
        self._decayFactor = decayFactor
        self._timeUnit = self._checkTimeUnit(timeUnit)
        # No model exists until setInitialCenters or setRandomCenters is called.
        self._model = None

    @classmethod
    def _checkTimeUnit(cls, timeUnit):
        """Return ``timeUnit`` unchanged, or raise ValueError if it is not accepted."""
        if timeUnit not in cls._TIME_UNITS:
            # Wrap in a tuple so a tuple-valued argument cannot break the
            # %-formatting of the error message itself.
            raise ValueError(
                "timeUnit should be 'batches' or 'points', got %s." % (timeUnit,))
        return timeUnit

    def latestModel(self):
        """Return the latest model, or None if no centers have been set yet."""
        return self._model

    def _validate(self, dstream):
        """Check that a model has been set and that ``dstream`` is a DStream.

        :raises ValueError: if no initial centers have been set.
        :raises TypeError: if ``dstream`` is not a DStream.
        """
        if self._model is None:
            raise ValueError(
                "Initial centers should be set either by setInitialCenters "
                "or setRandomCenters.")
        if not isinstance(dstream, DStream):
            raise TypeError(
                "Expected dstream to be of type DStream, "
                "got type %s" % type(dstream))

    def setK(self, k):
        """Set number of clusters. Returns self."""
        self._k = k
        return self

    def setDecayFactor(self, decayFactor):
        """Set decay factor. Returns self."""
        self._decayFactor = decayFactor
        return self

    def setHalfLife(self, halfLife, timeUnit):
        """
        Set the decay factor from a half-life: the number of time units
        (batches or points, according to ``timeUnit``) after which
        earlier data has half the weightage.

        :param halfLife: half-life in units of ``timeUnit``; a value of
                         zero raises ZeroDivisionError.
        :param timeUnit: "batches" or "points".
        :raises ValueError: if ``timeUnit`` is not an accepted value.
        """
        # Compute and validate first, assign afterwards, so that a failure
        # leaves the previously configured decay factor and unit intact.
        decayFactor = exp(log(0.5) / halfLife)
        timeUnit = self._checkTimeUnit(timeUnit)
        self._timeUnit = timeUnit
        self._decayFactor = decayFactor
        return self

    def setInitialCenters(self, centers, weights):
        """
        Set initial centers. Should be set before calling trainOn.
        """
        self._model = StreamingKMeansModel(centers, weights)
        return self

    def setRandomCenters(self, dim, weight, seed):
        """
        Set the initial centres to be random samples from
        a gaussian population with constant weights.

        :param dim: dimensionality of each center.
        :param weight: weight given to every one of the k centers.
        :param seed: seed for the numpy RandomState drawing the centers.
        """
        rng = random.RandomState(seed)
        # Uses the k configured at the time of this call.
        clusterCenters = rng.randn(self._k, dim)
        clusterWeights = tile(weight, self._k)
        self._model = StreamingKMeansModel(clusterCenters, clusterWeights)
        return self

    def trainOn(self, dstream):
        """Train the model on the incoming dstream."""
        self._validate(dstream)

        def update(rdd):
            # Decay factor and time unit are read at call time, once per RDD.
            self._model.update(rdd, self._decayFactor, self._timeUnit)

        dstream.foreachRDD(update)

    def predictOn(self, dstream):
        """
        Make predictions on a dstream.
        Returns a transformed dstream object
        """
        self._validate(dstream)
        return dstream.map(lambda x: self._model.predict(x))

    def predictOnValues(self, dstream):
        """
        Make predictions on a keyed dstream.
        Returns a transformed dstream object.
        """
        self._validate(dstream)
        return dstream.mapValues(lambda x: self._model.predict(x))
| 464 | + |
| 465 | + |
267 | 466 | def _test(): |
268 | 467 | import doctest |
269 | 468 | globs = globals().copy() |
|
0 commit comments