Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
79 changes: 77 additions & 2 deletions python/pyspark/ml/clustering.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

from pyspark import since, keyword_only
from pyspark.ml.util import *
from pyspark.ml.wrapper import JavaEstimator, JavaModel
from pyspark.ml.wrapper import JavaEstimator, JavaModel, JavaWrapper
from pyspark.ml.param.shared import *
from pyspark.ml.common import inherit_doc

Expand Down Expand Up @@ -56,8 +56,83 @@ def gaussiansDF(self):
"""
return self._call_java("gaussiansDF")

@property
@since("2.1.0")
def summary(self):
    """
    Gets summary (cluster assignments, cluster sizes) of the model
    trained on the training set. An exception is thrown if no summary
    exists (i.e. the Java side's `trainingSummary` is None).
    """
    # Wrap the JVM summary object in its Python counterpart so callers
    # get the pythonic GaussianMixtureSummary accessors.
    java_gmt_summary = self._call_java("summary")
    return GaussianMixtureSummary(java_gmt_summary)

@property
@since("2.1.0")
def hasSummary(self):
    """
    Indicates whether a training summary exists for this model
    instance. Check this before accessing :py:attr:`summary`, which
    raises when no summary is available.
    """
    # Delegates to the JVM model object.
    return self._call_java("hasSummary")


class GaussianMixtureSummary(JavaWrapper):
    """
    Abstraction for Gaussian Mixture results for a given model.

    All accessors delegate to the wrapped JVM summary object.

    .. versionadded:: 2.1.0
    """

    @property
    @since("2.1.0")
    def predictions(self):
        """
        DataFrame produced by the model's `transform` method.
        """
        return self._call_java("predictions")

    @property
    @since("2.1.0")
    def probabilityCol(self):
        """
        Name of the field in :py:attr:`predictions` which gives the
        probability of each class.
        """
        return self._call_java("probabilityCol")

    @property
    @since("2.1.0")
    def featuresCol(self):
        """
        Name of the field in :py:attr:`predictions` which gives the
        features of each instance.
        """
        return self._call_java("featuresCol")

    @property
    @since("2.1.0")
    def cluster(self):
        """
        DataFrame of predicted cluster assignments for the
        transformed data.
        """
        return self._call_java("cluster")

    @property
    @since("2.1.0")
    def probability(self):
        """
        DataFrame of probabilities of each cluster for each instance.
        """
        return self._call_java("probability")

    @property
    @since("2.1.0")
    def clusterSizes(self):
        """
        Size of (number of data points in) each cluster.
        """
        return self._call_java("clusterSizes")


@inherit_doc
class GaussianMixture(JavaEstimator, HasFeaturesCol, HasPredictionCol, HasMaxIter, HasTol, HasSeed,
HasProbabilityCol, JavaMLWritable, JavaMLReadable):
"""
Expand Down
15 changes: 15 additions & 0 deletions python/pyspark/ml/tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -1070,6 +1070,21 @@ def test_logistic_regression_summary(self):
sameSummary = model.evaluate(df)
self.assertAlmostEqual(sameSummary.areaUnderROC, s.areaUnderROC)

def test_gaussian_mixture_summary(self):
    """Smoke-test the GaussianMixtureModel training summary API."""
    # Data is loaded from a file (rather than hard-coded like the other
    # tests) because GaussianMixture needs dense feature vectors; the
    # sample_kmeans_data set is already used by tests elsewhere.
    df = self.spark.read.format("libsvm").load("data/mllib/sample_kmeans_data.txt")
    gm = GaussianMixture(k=3, tol=0.0001, maxIter=10, seed=10)
    model = gm.fit(df)
    self.assertTrue(model.hasSummary)
    s = model.summary
    # Test that the API is callable and returns the expected types.
    self.assertTrue(isinstance(s.predictions, DataFrame))
    self.assertEqual(s.featuresCol, "features")
    self.assertEqual(s.probabilityCol, "probability")
    self.assertTrue(isinstance(s.cluster, DataFrame))
    self.assertTrue(isinstance(s.probability, DataFrame))
    cluster_sizes = s.clusterSizes
    self.assertTrue(isinstance(cluster_sizes[0], int))


class OneVsRestTests(SparkSessionTestCase):

Expand Down