-
Notifications
You must be signed in to change notification settings - Fork 28.9k
[SPARK-14894][PySpark] Add result summary api to Gaussian Mixture #12675
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
c42f5dd
7db9c0d
b19756c
7d16a23
3cf080a
c2b1aef
3bc75a3
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -17,7 +17,7 @@ | |
|
|
||
| from pyspark import since, keyword_only | ||
| from pyspark.ml.util import * | ||
| from pyspark.ml.wrapper import JavaEstimator, JavaModel | ||
| from pyspark.ml.wrapper import JavaEstimator, JavaModel, JavaWrapper | ||
| from pyspark.ml.param.shared import * | ||
| from pyspark.ml.common import inherit_doc | ||
|
|
||
|
|
@@ -56,8 +56,83 @@ def gaussiansDF(self): | |
| """ | ||
| return self._call_java("gaussiansDF") | ||
|
|
||
| @property | ||
| @since("2.0.0") | ||
| def summary(self): | ||
| """ | ||
|
||
| Gets summary of model on training set. An exception is thrown if | ||
| `trainingSummary is None`. | ||
|
||
| """ | ||
| java_gmt_summary = self._call_java("summary") | ||
| return GaussianMixtureSummary(java_gmt_summary) | ||
|
|
||
| @property | ||
| @since("2.0.0") | ||
| def hasSummary(self): | ||
| """ | ||
| Indicates whether a training summary exists for this model | ||
| instance. | ||
| """ | ||
| return self._call_java("hasSummary") | ||
|
|
||
|
|
||
| class GaussianMixtureSummary(JavaWrapper): | ||
| """ | ||
| Abstraction for Gaussian Mixture Results for a given model. | ||
|
|
||
| .. versionadded:: 2.0.0 | ||
| """ | ||
|
|
||
| @property | ||
| @since("2.0.0") | ||
| def predictions(self): | ||
| """ | ||
| Dataframe outputted by the model's `transform` method. | ||
| """ | ||
| return self._call_java("predictions") | ||
|
|
||
| @property | ||
| @since("2.0.0") | ||
| def probabilityCol(self): | ||
| """ | ||
| Field in "predictions" which gives the probability | ||
| of each class. | ||
| """ | ||
| return self._call_java("probabilityCol") | ||
|
|
||
| @property | ||
| @since("2.0.0") | ||
| def featuresCol(self): | ||
| """ | ||
| Field in "predictions" which gives the features of each instance. | ||
| """ | ||
| return self._call_java("featuresCol") | ||
|
|
||
| @property | ||
| @since("2.0.0") | ||
| def cluster(self): | ||
| """ | ||
| Cluster centers of the transformed data. | ||
| """ | ||
| return self._call_java("cluster") | ||
|
|
||
| @property | ||
| @since("2.0.0") | ||
| def probability(self): | ||
| """ | ||
| Probability of each cluster. | ||
| """ | ||
| return self._call_java("probability") | ||
|
|
||
| @property | ||
| @since("2.0.0") | ||
| def clusterSizes(self): | ||
| """ | ||
| Size of (number of data points in) each cluster. | ||
| """ | ||
| return self._call_java("clusterSizes") | ||
|
|
||
|
|
||
| @inherit_doc | ||
| class GaussianMixture(JavaEstimator, HasFeaturesCol, HasPredictionCol, HasMaxIter, HasTol, HasSeed, | ||
| HasProbabilityCol, JavaMLWritable, JavaMLReadable): | ||
| """ | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -1070,6 +1070,21 @@ def test_logistic_regression_summary(self): | |
| sameSummary = model.evaluate(df) | ||
| self.assertAlmostEqual(sameSummary.areaUnderROC, s.areaUnderROC) | ||
|
|
||
| def test_gaussian_mixture_summary(self): | ||
| from pyspark.mllib.linalg import Vectors | ||
| df = self.spark.read.format("libsvm").load("data/mllib/sample_kmeans_data.txt") | ||
|
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Please let me know if it's OK to load data from a file when all the other test cases use hard-coded data values. I tried with a sparse vector, and fit gave me an error for the data format.
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think its ok to load a dataset which is already used for testing else where. |
||
| gm = GaussianMixture(k=3, tol=0.0001, maxIter=10, seed=10) | ||
| model = gm.fit(df) | ||
| self.assertTrue(model.hasSummary) | ||
| s = model.summary | ||
| # test that api is callable and returns expected types | ||
| self.assertTrue(isinstance(s.predictions, DataFrame)) | ||
| self.assertEqual(s.featuresCol, "features") | ||
| cluster_sizes = s.clusterSizes | ||
| self.assertTrue(isinstance(s.cluster, DataFrame)) | ||
| self.assertTrue(isinstance(s.probability, DataFrame)) | ||
| self.assertTrue(isinstance(cluster_sizes[0], int)) | ||
|
|
||
|
|
||
| class OneVsRestTests(SparkSessionTestCase): | ||
|
|
||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Since we've shipped 2.0, we will need to update this to 2.1 (same with the versionAdded notes).