Skip to content

Commit daec9fe

Browse files
break before exceeding array size
1 parent 94d5c8c commit daec9fe

File tree

1 file changed

+5
-1
lines changed

1 file changed

+5
-1
lines changed

ot/lp/emd_wrap.pyx

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -172,7 +172,7 @@ def emd_1d_sorted(np.ndarray[double, ndim=1, mode="c"] u_weights,
172172
cdef np.ndarray[long, ndim=2, mode="c"] indices = np.zeros((n + m - 1, 2),
173173
dtype=np.int)
174174
cdef int cur_idx = 0
175-
while i < n and j < m:
175+
while True:
176176
if metric == 'sqeuclidean':
177177
m_ij = (u[i] - v[j]) * (u[i] - v[j])
178178
elif metric == 'cityblock' or metric == 'euclidean':
@@ -188,6 +188,8 @@ def emd_1d_sorted(np.ndarray[double, ndim=1, mode="c"] u_weights,
188188
indices[cur_idx, 0] = i
189189
indices[cur_idx, 1] = j
190190
i += 1
191+
if i == n:
192+
break
191193
w_j -= w_i
192194
w_i = u_weights[i]
193195
else:
@@ -196,6 +198,8 @@ def emd_1d_sorted(np.ndarray[double, ndim=1, mode="c"] u_weights,
196198
indices[cur_idx, 0] = i
197199
indices[cur_idx, 1] = j
198200
j += 1
201+
if j == m:
202+
break
199203
w_i -= w_j
200204
w_j = v_weights[j]
201205
cur_idx += 1

0 commit comments

Comments
 (0)