Skip to content

Commit 7dde9e8

Browse files
authored
[MRG] Regularized OT (optim.cg) bug solve (#286)
* Line search stops when derphi is 0 instead of bugging out like in some instances * pep8 compliance * Tests
1 parent e0ba31c commit 7dde9e8

File tree

3 files changed

+39
-4
lines changed

3 files changed

+39
-4
lines changed

ot/optim.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -178,9 +178,9 @@ def cg(a, b, M, reg, f, df, G0=None, numItermax=200, numItermaxEmd=100000,
178178
numItermaxEmd : int, optional
179179
Max number of iterations for emd
180180
stopThr : float, optional
181-
Stop threshol on the relative variation (>0)
181+
Stop threshold on the relative variation (>0)
182182
stopThr2 : float, optional
183-
Stop threshol on the absolute variation (>0)
183+
Stop threshold on the absolute variation (>0)
184184
verbose : bool, optional
185185
Print information along iterations
186186
log : bool, optional
@@ -249,6 +249,8 @@ def cost(G):
249249

250250
# line search
251251
alpha, fc, f_val = solve_linesearch(cost, G, deltaG, Mi, f_val, reg=reg, M=M, Gc=Gc, **kwargs)
252+
if alpha is None:
253+
alpha = 0.0
252254

253255
G = G + alpha * deltaG
254256

@@ -320,9 +322,9 @@ def gcg(a, b, M, reg1, reg2, f, df, G0=None, numItermax=10,
320322
numInnerItermax : int, optional
321323
Max number of iterations of Sinkhorn
322324
stopThr : float, optional
323-
Stop threshol on the relative variation (>0)
325+
Stop threshold on the relative variation (>0)
324326
stopThr2 : float, optional
325-
Stop threshol on the absolute variation (>0)
327+
Stop threshold on the absolute variation (>0)
326328
verbose : bool, optional
327329
Print information along iterations
328330
log : bool, optional

test/test_da.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -565,6 +565,14 @@ def test_mapping_transport_class():
565565
otda.fit(Xs=Xs, Xt=Xt)
566566
assert len(otda.log_.keys()) != 0
567567

568+
# check that it does not crash when derphi is very close to 0
569+
np.random.seed(39)
570+
Xs, ys = make_data_classif('3gauss', ns)
571+
Xt, yt = make_data_classif('3gauss2', nt)
572+
otda = ot.da.MappingTransport(kernel="gaussian", bias=False)
573+
otda.fit(Xs=Xs, Xt=Xt)
574+
np.random.seed(None)
575+
568576

569577
def test_linear_mapping():
570578
ns = 150

test/test_optim.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -114,3 +114,28 @@ def test_line_search_armijo():
114114
# Should not throw an exception and return None for alpha
115115
alpha, _, _ = ot.optim.line_search_armijo(lambda x: 1, xk, pk, gfk, old_fval)
116116
assert alpha is None
117+
118+
# check line search armijo
119+
def f(x):
120+
return np.sum((x - 5.0) ** 2)
121+
122+
def grad(x):
123+
return 2 * (x - 5.0)
124+
125+
xk = np.array([[[-5.0, -5.0]]])
126+
pk = np.array([[[100.0, 100.0]]])
127+
gfk = grad(xk)
128+
old_fval = f(xk)
129+
130+
# chech the case where the optimum is on the direction
131+
alpha, _, _ = ot.optim.line_search_armijo(f, xk, pk, gfk, old_fval)
132+
np.testing.assert_allclose(alpha, 0.1)
133+
134+
# check the case where the direction is not far enough
135+
pk = np.array([[[3.0, 3.0]]])
136+
alpha, _, _ = ot.optim.line_search_armijo(f, xk, pk, gfk, old_fval, alpha0=1.0)
137+
np.testing.assert_allclose(alpha, 1.0)
138+
139+
# check the case where the checking the wrong direction
140+
alpha, _, _ = ot.optim.line_search_armijo(f, xk, -pk, gfk, old_fval)
141+
assert alpha <= 0

0 commit comments

Comments
 (0)