diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py index 04b07e6a05481..a444fe0f98a4f 100755 --- a/python/pyspark/ml/feature.py +++ b/python/pyspark/ml/feature.py @@ -3673,12 +3673,12 @@ def setParams(self, inputCol=None, size=None, handleInvalid="error"): @since("2.3.0") def getSize(self): """ Gets size param, the size of vectors in `inputCol`.""" - self.getOrDefault(self.size) + return self.getOrDefault(self.size) @since("2.3.0") def setSize(self, value): """ Sets size param, the size of vectors in `inputCol`.""" - self._set(size=value) + return self._set(size=value) if __name__ == "__main__": diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py index 1af2b91da900d..49912d20da1d8 100755 --- a/python/pyspark/ml/tests.py +++ b/python/pyspark/ml/tests.py @@ -678,6 +678,23 @@ def test_string_indexer_handle_invalid(self): expected2 = [Row(id=0, indexed=0.0), Row(id=1, indexed=1.0)] self.assertEqual(actual2, expected2) + def test_vector_size_hint(self): + df = self.spark.createDataFrame( + [(0, Vectors.dense([0.0, 10.0, 0.5])), + (1, Vectors.dense([1.0, 11.0, 0.5, 0.6])), + (2, Vectors.dense([2.0, 12.0]))], + ["id", "vector"]) + + sizeHint = VectorSizeHint( + inputCol="vector", + handleInvalid="skip") + sizeHint.setSize(3) + self.assertEqual(sizeHint.getSize(), 3) + + output = sizeHint.transform(df).head().vector + expected = DenseVector([0.0, 10.0, 0.5]) + self.assertEqual(output, expected) + class HasInducedError(Params):