Skip to content

Commit 25d681f

Browse files
committed
ENH: python, cache weightCol
1 parent 931d02d commit 25d681f

File tree

2 files changed

+14
-1
lines changed

2 files changed

+14
-1
lines changed

python/pyspark/ml/classification.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1546,7 +1546,10 @@ def _fit(self, dataset):
15461546

15471547
numClasses = int(dataset.agg({labelCol: "max"}).head()["max("+labelCol+")"]) + 1
15481548

1549-
multiclassLabeled = dataset.select(labelCol, featuresCol)
1549+
if isinstance(classifier, HasWeightCol) and classifier.getWeightCol():
1550+
multiclassLabeled = dataset.select(labelCol, featuresCol, classifier.getWeightCol())
1551+
else:
1552+
multiclassLabeled = dataset.select(labelCol, featuresCol)
15501553

15511554
# persist if underlying dataset is not persistent.
15521555
handlePersistence = \

python/pyspark/ml/tests.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1255,6 +1255,16 @@ def test_output_columns(self):
12551255
output = model.transform(df)
12561256
self.assertEqual(output.columns, ["label", "features", "prediction"])
12571257

1258+
def test_cache_weightCol_if_necessary(self):
    """Fit OneVsRest over a base classifier that is configured with a weightCol.

    Builds a three-row, three-class DataFrame with "label", "features" and
    "weight" columns, wraps a LogisticRegression that reads
    ``weightCol="weight"`` in OneVsRest, and checks that fitting returns a
    model.
    """
    df = self.spark.createDataFrame(
        [(0.0, Vectors.dense(1.0, 0.8), 1.0),
         (1.0, Vectors.sparse(2, [], []), 1.0),
         (2.0, Vectors.dense(0.5, 0.5), 1.0)],
        ["label", "features", "weight"])
    lr = LogisticRegression(maxIter=5, regParam=0.01, weightCol="weight")
    ovr = OneVsRest(classifier=lr)
    model = ovr.fit(df)
    # fit() returns the fitted model, so the test must assert that it is
    # present; asserting None would fail whenever fitting succeeds.
    self.assertIsNotNone(model)
1267+
12581268

12591269
class HashingTFTest(SparkSessionTestCase):
12601270

0 commit comments

Comments
 (0)