Skip to content

Commit 1ffa8cb

Browse files
committed
[SPARK-7329] [MLLIB] simplify ParamGridBuilder impl
as suggested by justinuang on apache#5601. Author: Xiangrui Meng <[email protected]> Closes apache#5873 from mengxr/SPARK-7329 and squashes the following commits: d08f9cf [Xiangrui Meng] simplify tests b7a7b9b [Xiangrui Meng] simplify grid build
1 parent 9e25b09 commit 1ffa8cb

File tree

1 file changed

+9
-19
lines changed

1 file changed

+9
-19
lines changed

python/pyspark/ml/tuning.py

Lines changed: 9 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@
1515
# limitations under the License.
1616
#
1717

18+
import itertools
19+
1820
__all__ = ['ParamGridBuilder']
1921

2022

@@ -37,14 +39,10 @@ class ParamGridBuilder(object):
3739
{lr.regParam: 1.0, lr.featuresCol: 'f', lr.maxIter: 5, lr.labelCol: 'l', lr.predictionCol: 'p'}, \
3840
{lr.regParam: 2.0, lr.featuresCol: 'f', lr.maxIter: 5, lr.labelCol: 'l', lr.predictionCol: 'p'}, \
3941
{lr.regParam: 3.0, lr.featuresCol: 'f', lr.maxIter: 5, lr.labelCol: 'l', lr.predictionCol: 'p'}]
40-
>>> fail_count = 0
41-
>>> for e in expected:
42-
... if e not in output:
43-
... fail_count += 1
44-
>>> if len(expected) != len(output):
45-
... fail_count += 1
46-
>>> fail_count
47-
0
42+
>>> len(output) == len(expected)
43+
True
44+
>>> all([m in expected for m in output])
45+
True
4846
"""
4947

5048
def __init__(self):
@@ -76,17 +74,9 @@ def build(self):
7674
Builds and returns all combinations of parameters specified
7775
by the param grid.
7876
"""
79-
param_maps = [{}]
80-
for (param, values) in self._param_grid.items():
81-
new_param_maps = []
82-
for value in values:
83-
for old_map in param_maps:
84-
copied_map = old_map.copy()
85-
copied_map[param] = value
86-
new_param_maps.append(copied_map)
87-
param_maps = new_param_maps
88-
89-
return param_maps
77+
keys = self._param_grid.keys()
78+
grid_values = self._param_grid.values()
79+
return [dict(zip(keys, prod)) for prod in itertools.product(*grid_values)]
9080

9181

9282
if __name__ == "__main__":

0 commit comments

Comments
 (0)