Skip to content

Commit cf4df64

Browse files
committed
revert non-parallel process
1 parent ecdc742 commit cf4df64

File tree

1 file changed

+6
-7
lines changed

1 file changed

+6
-7
lines changed

python/pyspark/ml/classification.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717

1818
import operator
1919
import warnings
20-
from multiprocessing.dummy import Pool
2120

2221
from pyspark.ml import Estimator, Model
2322
from pyspark.ml.param.shared import *
@@ -1202,6 +1201,9 @@ def setParams(self, featuresCol=None, labelCol=None, predictionCol=None, classif
12021201
def setClassifier(self, value):
12031202
"""
12041203
Sets the value of :py:attr:`classifier`.
1204+
1205+
.. note:: Only LogisticRegression, NaiveBayes and MultilayerPerceptronClassifier are
1206+
supported now.
12051207
"""
12061208
self._paramMap[self.classifier] = value
12071209
return self
@@ -1237,13 +1239,10 @@ def trainSingleClass(index):
12371239
paramMap = dict([(classifier.labelCol, binaryLabelCol),
12381240
(classifier.featuresCol, featuresCol),
12391241
(classifier.predictionCol, predictionCol)])
1240-
duplicatedClassifier = classifier.__class__()
1241-
duplicatedClassifier._resetUid(classifier.uid)
1242-
classifier._copyValues(duplicatedClassifier)
1243-
return duplicatedClassifier.fit(trainingDataset, paramMap)
1242+
return classifier.fit(trainingDataset, paramMap)
12441243

1245-
pool = Pool()
1246-
models = pool.map(trainSingleClass, range(numClasses))
1244+
# TODO: Parallel training for all classes.
1245+
models = [trainSingleClass(i) for i in range(numClasses)]
12471246

12481247
if handlePersistence:
12491248
multiclassLabeled.unpersist()

0 commit comments

Comments
 (0)