Skip to content

Commit 4af4b35

Browse files
committed
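Fix test issues: move the RFormula stringIndexerOrderType test from python/pyspark/tests.py into python/pyspark/ml/tests.py and drop the RFormula import from python/pyspark/tests.py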
1 parent 320203e commit 4af4b35

File tree

2 files changed

+14
-20
lines changed

2 files changed

+14
-20
lines changed

python/pyspark/ml/tests.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -538,6 +538,19 @@ def test_rformula_force_index_label(self):
538538
transformedDF2 = model2.transform(df)
539539
self.assertEqual(transformedDF2.head().label, 0.0)
540540

541+
def test_rformula_string_indexer_order_type(self):
542+
df = self.spark.createDataFrame([
543+
(1.0, 1.0, "a"),
544+
(0.0, 2.0, "b"),
545+
(1.0, 0.0, "a")], ["y", "x", "s"])
546+
rf = RFormula(formula="y ~ x + s", stringIndexerOrderType="alphabetDesc")
547+
self.assertEqual(rf.getStringIndexerOrderType(), 'alphabetDesc')
548+
transformedDF = rf.fit(df).transform(df)
549+
observed = transformedDF.select("features").collect()
550+
expected = [[1.0, 0.0], [2.0, 1.0], [0.0, 0.0]]
551+
for i in range(0, len(expected)):
552+
self.assertTrue((observed[i]["features"].toArray() == expected[i]).all())
553+
541554

542555
class HasInducedError(Params):
543556

python/pyspark/tests.py

Lines changed: 1 addition & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -61,9 +61,8 @@
6161
from pyspark import keyword_only
6262
from pyspark.conf import SparkConf
6363
from pyspark.context import SparkContext
64-
from pyspark.files import SparkFiles
65-
from pyspark.ml.feature import RFormula
6664
from pyspark.rdd import RDD
65+
from pyspark.files import SparkFiles
6766
from pyspark.serializers import read_int, BatchedSerializer, MarshalSerializer, PickleSerializer, \
6867
CloudPickleSerializer, CompressedSerializer, UTF8Deserializer, NoOpSerializer, \
6968
PairDeserializer, CartesianDeserializer, AutoBatchedSerializer, AutoSerializer, \
@@ -2207,24 +2206,6 @@ def set(self, x=None, other=None, other_x=None):
22072206
self.assertEqual(b._x, 2)
22082207

22092208

2210-
class SparkMLTests(ReusedPySparkTestCase):
2211-
2212-
def test_rformula(self):
2213-
df = self.sc.parallelize([
2214-
(1.0, 1.0, "a"),
2215-
(0.0, 2.0, "b"),
2216-
(0.0, 0.0, "a")
2217-
]).toDF(["y", "x", "s"])
2218-
rf = RFormula(formula="y ~ x + s", stringIndexerOrderType="alphabetDesc")
2219-
self.assertEqual(rf.getStringIndexerOrderType(), 'alphabetDesc')
2220-
2221-
result = rf.fit(df).transform(df)
2222-
observed = result.select("features").collect()
2223-
expected = [[1.0, 0.0], [2.0, 1.0], [0.0, 0.0]]
2224-
for i in range(0, len(expected)):
2225-
self.assertEqual(observed[i]["features"].toArray(), expected[i])
2226-
2227-
22282209
@unittest.skipIf(not _have_scipy, "SciPy not installed")
22292210
class SciPyTests(PySparkTestCase):
22302211

0 commit comments

Comments
 (0)