@@ -26,7 +26,7 @@ def test_criterions(self):
2626 X = numpy .array ([[1. , 2. ]]).T
2727 y = numpy .array ([1. , 2. ])
2828 c1 = MSE (1 , X .shape [0 ])
29- c2 = SimpleRegressorCriterion (X )
29+ c2 = SimpleRegressorCriterion (1 , X . shape [ 0 ] )
3030 self .assertNotEmpty (c1 )
3131 self .assertNotEmpty (c2 )
3232 w = numpy .ones ((y .shape [0 ],))
@@ -49,7 +49,7 @@ def test_criterions(self):
4949 X = numpy .array ([[1. , 2. , 3. ]]).T
5050 y = numpy .array ([1. , 2. , 3. ])
5151 c1 = MSE (1 , X .shape [0 ])
52- c2 = SimpleRegressorCriterion (X )
52+ c2 = SimpleRegressorCriterion (1 , X . shape [ 0 ] )
5353 w = numpy .ones ((y .shape [0 ],))
5454 ind = numpy .arange (y .shape [0 ]).astype (numpy .int64 )
5555 ys = y .astype (float ).reshape ((y .shape [0 ], 1 ))
@@ -68,7 +68,7 @@ def test_criterions(self):
6868 X = numpy .array ([[1. , 2. , 10. , 11. ]]).T
6969 y = numpy .array ([0.9 , 1.1 , 1.9 , 2.1 ])
7070 c1 = MSE (1 , X .shape [0 ])
71- c2 = SimpleRegressorCriterion (X )
71+ c2 = SimpleRegressorCriterion (1 , X . shape [ 0 ] )
7272 w = numpy .ones ((y .shape [0 ],))
7373 ind = numpy .arange (y .shape [0 ]).astype (numpy .int64 )
7474 ys = y .astype (float ).reshape ((y .shape [0 ], 1 ))
@@ -121,7 +121,7 @@ def test_criterions(self):
121121 X = numpy .array ([[1. , 2. , 10. , 11. ]]).T
122122 y = numpy .array ([0.9 , 1.1 , 1.9 , 2.1 ])
123123 c1 = MSE (1 , X .shape [0 ])
124- c2 = SimpleRegressorCriterion (X )
124+ c2 = SimpleRegressorCriterion (1 , X . shape [ 0 ] )
125125 w = numpy .ones ((y .shape [0 ],))
126126 ind = numpy .array ([0 , 3 , 2 , 1 ], dtype = ind .dtype )
127127 ys = y .astype (float ).reshape ((y .shape [0 ], 1 ))
@@ -166,7 +166,8 @@ def test_decision_tree_criterion(self):
166166 clr1 .fit (X , y )
167167 p1 = clr1 .predict (X )
168168
169- crit = SimpleRegressorCriterion (X )
169+ crit = SimpleRegressorCriterion (
170+ 1 if len (y .shape ) <= 1 else y .shape [1 ], X .shape [0 ])
170171 clr2 = DecisionTreeRegressor (criterion = crit , max_depth = 1 )
171172 clr2 .fit (X , y )
172173 p2 = clr2 .predict (X )
@@ -179,7 +180,9 @@ def test_decision_tree_criterion_iris(self):
179180 clr1 = DecisionTreeRegressor ()
180181 clr1 .fit (X , y )
181182 p1 = clr1 .predict (X )
182- clr2 = DecisionTreeRegressor (criterion = SimpleRegressorCriterion (X ))
183+ clr2 = DecisionTreeRegressor (
184+ criterion = SimpleRegressorCriterion (
185+ 1 if len (y .shape ) <= 1 else y .shape [1 ], X .shape [0 ]))
183186 clr2 .fit (X , y )
184187 p2 = clr2 .predict (X )
185188 self .assertEqual (p1 [:10 ], p2 [:10 ])
0 commit comments