Skip to content

Commit ff5676b

Browse files
actuaryzhangyanboliang
authored andcommitted
[SPARK-20899][PYSPARK] PySpark supports stringIndexerOrderType in RFormula
## What changes were proposed in this pull request? PySpark supports stringIndexerOrderType in RFormula as in #17967. ## How was this patch tested? docstring test Author: actuaryzhang <[email protected]> Closes #18122 from actuaryzhang/PythonRFormula.
1 parent 35b644b commit ff5676b

File tree

2 files changed

+41
-5
lines changed

2 files changed

+41
-5
lines changed

python/pyspark/ml/feature.py

Lines changed: 28 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3043,26 +3043,35 @@ class RFormula(JavaEstimator, HasFeaturesCol, HasLabelCol, JavaMLReadable, JavaM
30433043
"Force to index label whether it is numeric or string",
30443044
typeConverter=TypeConverters.toBoolean)
30453045

3046+
stringIndexerOrderType = Param(Params._dummy(), "stringIndexerOrderType",
3047+
"How to order categories of a string feature column used by " +
3048+
"StringIndexer. The last category after ordering is dropped " +
3049+
"when encoding strings. Supported options: frequencyDesc, " +
3050+
"frequencyAsc, alphabetDesc, alphabetAsc. The default value " +
3051+
"is frequencyDesc. When the ordering is set to alphabetDesc, " +
3052+
"RFormula drops the same category as R when encoding strings.",
3053+
typeConverter=TypeConverters.toString)
3054+
30463055
@keyword_only
30473056
def __init__(self, formula=None, featuresCol="features", labelCol="label",
3048-
forceIndexLabel=False):
3057+
forceIndexLabel=False, stringIndexerOrderType="frequencyDesc"):
30493058
"""
30503059
__init__(self, formula=None, featuresCol="features", labelCol="label", \
3051-
forceIndexLabel=False)
3060+
forceIndexLabel=False, stringIndexerOrderType="frequencyDesc")
30523061
"""
30533062
super(RFormula, self).__init__()
30543063
self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.RFormula", self.uid)
3055-
self._setDefault(forceIndexLabel=False)
3064+
self._setDefault(forceIndexLabel=False, stringIndexerOrderType="frequencyDesc")
30563065
kwargs = self._input_kwargs
30573066
self.setParams(**kwargs)
30583067

30593068
@keyword_only
30603069
@since("1.5.0")
30613070
def setParams(self, formula=None, featuresCol="features", labelCol="label",
3062-
forceIndexLabel=False):
3071+
forceIndexLabel=False, stringIndexerOrderType="frequencyDesc"):
30633072
"""
30643073
setParams(self, formula=None, featuresCol="features", labelCol="label", \
3065-
forceIndexLabel=False)
3074+
forceIndexLabel=False, stringIndexerOrderType="frequencyDesc")
30663075
Sets params for RFormula.
30673076
"""
30683077
kwargs = self._input_kwargs
@@ -3096,6 +3105,20 @@ def getForceIndexLabel(self):
30963105
"""
30973106
return self.getOrDefault(self.forceIndexLabel)
30983107

3108+
@since("2.3.0")
3109+
def setStringIndexerOrderType(self, value):
3110+
"""
3111+
Sets the value of :py:attr:`stringIndexerOrderType`.
3112+
"""
3113+
return self._set(stringIndexerOrderType=value)
3114+
3115+
@since("2.3.0")
3116+
def getStringIndexerOrderType(self):
3117+
"""
3118+
Gets the value of :py:attr:`stringIndexerOrderType` or its default value 'frequencyDesc'.
3119+
"""
3120+
return self.getOrDefault(self.stringIndexerOrderType)
3121+
30993122
def _create_model(self, java_model):
31003123
return RFormulaModel(java_model)
31013124

python/pyspark/ml/tests.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -538,6 +538,19 @@ def test_rformula_force_index_label(self):
538538
transformedDF2 = model2.transform(df)
539539
self.assertEqual(transformedDF2.head().label, 0.0)
540540

541+
def test_rformula_string_indexer_order_type(self):
542+
df = self.spark.createDataFrame([
543+
(1.0, 1.0, "a"),
544+
(0.0, 2.0, "b"),
545+
(1.0, 0.0, "a")], ["y", "x", "s"])
546+
rf = RFormula(formula="y ~ x + s", stringIndexerOrderType="alphabetDesc")
547+
self.assertEqual(rf.getStringIndexerOrderType(), 'alphabetDesc')
548+
transformedDF = rf.fit(df).transform(df)
549+
observed = transformedDF.select("features").collect()
550+
expected = [[1.0, 0.0], [2.0, 1.0], [0.0, 0.0]]
551+
for i in range(0, len(expected)):
552+
self.assertTrue(all(observed[i]["features"].toArray() == expected[i]))
553+
541554

542555
class HasInducedError(Params):
543556

0 commit comments

Comments
 (0)