 
 import operator
 import warnings
-from multiprocessing.dummy import Pool
 
 from pyspark.ml import Estimator, Model
 from pyspark.ml.param.shared import *
@@ -1202,6 +1201,9 @@ def setParams(self, featuresCol=None, labelCol=None, predictionCol=None, classif
     def setClassifier(self, value):
         """
         Sets the value of :py:attr:`classifier`.
+
+        .. note:: Only LogisticRegression, NaiveBayes and MultilayerPerceptronClassifier are
+            supported now.
         """
         self._paramMap[self.classifier] = value
         return self
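As a brief aside on the note added above, a minimal usage sketch for setClassifier could look like the following; the `train` and `test` DataFrames (with the default "label" and "features" columns) are hypothetical and not part of this patch:

    from pyspark.ml.classification import LogisticRegression, OneVsRest

    # Hypothetical DataFrames `train` and `test` with "label" and "features" columns.
    lr = LogisticRegression(maxIter=10, regParam=0.01)
    ovr = OneVsRest().setClassifier(lr)  # one of the currently supported base classifiers
    ovrModel = ovr.fit(train)
    predictions = ovrModel.transform(test)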
@@ -1237,13 +1239,10 @@ def trainSingleClass(index):
             paramMap = dict([(classifier.labelCol, binaryLabelCol),
                              (classifier.featuresCol, featuresCol),
                              (classifier.predictionCol, predictionCol)])
-            duplicatedClassifier = classifier.__class__()
-            duplicatedClassifier._resetUid(classifier.uid)
-            classifier._copyValues(duplicatedClassifier)
-            return duplicatedClassifier.fit(trainingDataset, paramMap)
+            return classifier.fit(trainingDataset, paramMap)
 
-        pool = Pool()
-        models = pool.map(trainSingleClass, range(numClasses))
+        # TODO: Parallel training for all classes.
+        models = [trainSingleClass(i) for i in range(numClasses)]
 
         if handlePersistence:
             multiclassLabeled.unpersist()
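For context, here is a standalone sketch of how the one-vs-rest training loop behaves after this change. The relabeling step (binaryLabelCol, when(...)) mirrors the surrounding _fit code rather than the hunk shown above, and the function name fit_one_vs_rest is hypothetical:

    from pyspark.sql.functions import when

    def fit_one_vs_rest(classifier, multiclassLabeled, numClasses,
                        labelCol="label", featuresCol="features", predictionCol="prediction"):
        """Hypothetical standalone version of the per-class training loop."""
        models = []
        for index in range(numClasses):
            binaryLabelCol = "mc2b$" + str(index)
            # Relabel the dataset: 1.0 for the current class, 0.0 for every other class.
            trainingDataset = multiclassLabeled.withColumn(
                binaryLabelCol,
                when(multiclassLabeled[labelCol] == float(index), 1.0).otherwise(0.0))
            paramMap = {classifier.labelCol: binaryLabelCol,
                        classifier.featuresCol: featuresCol,
                        classifier.predictionCol: predictionCol}
            # fit() with a param map copies the classifier internally, so the same
            # instance can be reused for every class.
            models.append(classifier.fit(trainingDataset, paramMap))
        return models

The explicit duplication of the classifier is no longer needed because Estimator.fit with a param map copies the estimator before fitting; parallel training across classes remains future work per the TODO comment.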