Skip to content

Commit fa8bf65

Browse files
committed
Support parameter positive for QuantileLinearRegression
1 parent 7959664 commit fa8bf65

File tree

2 files changed

+37
-5
lines changed

2 files changed

+37
-5
lines changed

_unittests/ut_mlmodel/test_quantile_regression.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,22 @@ def test_quantile_regression_no_intercept(self):
3232
self.assertEqual(clq.intercept_, 0)
3333
self.assertEqualArray(clr.intercept_, clq.intercept_)
3434

35+
def test_quantile_regression_no_intercept_positive(self):
36+
X = numpy.array([[0.1, 0.2], [0.2, 0.3]])
37+
Y = numpy.array([1., 1.1])
38+
clr = LinearRegression(fit_intercept=False, positive=True)
39+
clr.fit(X, Y)
40+
clq = QuantileLinearRegression(fit_intercept=False, positive=True)
41+
clq.fit(X, Y)
42+
self.assertEqual(clr.intercept_, 0)
43+
self.assertEqual(clq.intercept_, 0)
44+
self.assertGreater(clr.coef_.min(), 0)
45+
self.assertGreater(clq.coef_.min(), 0)
46+
self.assertEqualArray(clr.intercept_, clq.intercept_)
47+
self.assertEqualArray(clr.coef_[0], clq.coef_[0])
48+
self.assertGreater(clr.coef_[1:].min(), 3)
49+
self.assertGreater(clq.coef_[1:].min(), 3)
50+
3551
def test_quantile_regression_intercept(self):
3652
X = numpy.array([[0.1, 0.2], [0.2, 0.3], [0.3, 0.3]])
3753
Y = numpy.array([1., 1.1, 1.2])
@@ -44,6 +60,21 @@ def test_quantile_regression_intercept(self):
4460
self.assertEqualArray(clr.intercept_, clq.intercept_)
4561
self.assertEqualArray(clr.coef_, clq.coef_)
4662

63+
def test_quantile_regression_intercept_positive(self):
64+
X = numpy.array([[0.1, 0.2], [0.2, 0.3], [0.3, 0.3]])
65+
Y = numpy.array([1., 1.1, 1.2])
66+
clr = LinearRegression(fit_intercept=True, positive=True)
67+
clr.fit(X, Y)
68+
clq = QuantileLinearRegression(
69+
verbose=False, fit_intercept=True, positive=True)
70+
clq.fit(X, Y)
71+
self.assertNotEqual(clr.intercept_, 0)
72+
self.assertNotEqual(clq.intercept_, 0)
73+
self.assertEqualArray(clr.intercept_, clq.intercept_)
74+
self.assertEqualArray(clr.coef_, clq.coef_)
75+
self.assertGreater(clr.coef_.min(), 0)
76+
self.assertGreater(clq.coef_.min(), 0)
77+
4778
def test_quantile_regression_intercept_weights(self):
4879
X = numpy.array([[0.1, 0.2], [0.2, 0.3], [0.3, 0.3]])
4980
Y = numpy.array([1., 1.1, 1.2])

mlinsights/mlmodel/quantile_regression.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ class QuantileLinearRegression(LinearRegression):
3131

3232
def __init__(self, fit_intercept=True, normalize=False, copy_X=True,
3333
n_jobs=1, delta=0.0001, max_iter=10, quantile=0.5,
34-
verbose=False):
34+
positive=False, verbose=False):
3535
"""
3636
:param fit_intercept: boolean, optional, default True
3737
whether to calculate the intercept for this model. If set
@@ -59,11 +59,12 @@ def __init__(self, fit_intercept=True, normalize=False, copy_X=True,
5959
:param quantile: float, by default 0.5,
6060
determines which quantile to use
6161
to estimate the regression.
62+
:param positive: when set to True, forces the coefficients to be positive.
6263
:param verbose: bool, optional, default False
6364
Prints error at each iteration of the optimisation.
6465
"""
6566
LinearRegression.__init__(self, fit_intercept=fit_intercept, normalize=normalize,
66-
copy_X=copy_X, n_jobs=n_jobs)
67+
copy_X=copy_X, n_jobs=n_jobs, positive=positive)
6768
self.max_iter = max_iter
6869
self.verbose = verbose
6970
self.delta = delta
@@ -131,7 +132,8 @@ def compute_z(Xm, beta, Y, W, delta=0.0001):
131132
Xm = X
132133

133134
clr = LinearRegression(fit_intercept=False, copy_X=self.copy_X,
134-
n_jobs=self.n_jobs, normalize=self.normalize)
135+
n_jobs=self.n_jobs, normalize=self.normalize,
136+
positive=self.positive)
135137

136138
W = numpy.ones(X.shape[0]) if sample_weight is None else sample_weight
137139
self.n_iter_ = 0
@@ -197,5 +199,4 @@ def score(self, X, y, sample_weight=None):
197199
if mult is not None:
198200
epsilon *= mult * 2
199201
return epsilon.sum() / X.shape[0]
200-
else:
201-
return mean_absolute_error(y, pred, sample_weight=sample_weight)
202+
return mean_absolute_error(y, pred, sample_weight=sample_weight)

0 commit comments

Comments
 (0)