From 04c12a0b91c058fec187bb69c7a81a7e80c4dc22 Mon Sep 17 00:00:00 2001 From: nathanneike Date: Tue, 28 Oct 2025 16:01:32 +0100 Subject: [PATCH 1/5] Add sparse EMD solver with unit tests - Implement sparse bipartite graph EMD solver in C++ - Add Python bindings for sparse solver (emd_wrap.pyx, _network_simplex.py) - Add unit tests to verify sparse and dense solvers produce identical results - Tests use augmented k-NN approach to ensure fair comparison - Update setup.py to include sparse solver compilation Both test_emd_sparse_vs_dense() and test_emd2_sparse_vs_dense() verify: * Identical costs between sparse and dense solvers * Marginal constraint satisfaction for both solvers --- ot/lp/EMD.h | 18 ++ ot/lp/EMD_wrapper.cpp | 155 +++++++++++++++++ ot/lp/_network_simplex.py | 302 +++++++++++++++++++++++++++------- ot/lp/emd_wrap.pyx | 76 +++++++++ ot/lp/sparse_bipartitegraph.h | 281 +++++++++++++++++++++++++++++++ setup.py | 14 +- test/test_ot.py | 149 ++++++++++++++++- 7 files changed, 932 insertions(+), 63 deletions(-) create mode 100644 ot/lp/sparse_bipartitegraph.h diff --git a/ot/lp/EMD.h b/ot/lp/EMD.h index b56f0601b..efa839bcf 100644 --- a/ot/lp/EMD.h +++ b/ot/lp/EMD.h @@ -32,6 +32,24 @@ enum ProblemType { int EMD_wrap(int n1,int n2, double *X, double *Y,double *D, double *G, double* alpha, double* beta, double *cost, uint64_t maxIter); int EMD_wrap_omp(int n1,int n2, double *X, double *Y,double *D, double *G, double* alpha, double* beta, double *cost, uint64_t maxIter, int numThreads); +int EMD_wrap_sparse( + int n1, + int n2, + double *X, + double *Y, + uint64_t n_edges, // Number of edges in sparse graph + int64_t *edge_sources, // Source indices for each edge (n_edges) + int64_t *edge_targets, // Target indices for each edge (n_edges) + double *edge_costs, // Cost for each edge (n_edges) + int64_t *flow_sources_out, // Output: source indices of non-zero flows + int64_t *flow_targets_out, // Output: target indices of non-zero flows + double *flow_values_out, 
// Output: flow values + uint64_t *n_flows_out, + double *alpha, // Output: dual variables for sources (n1) + double *beta, // Output: dual variables for targets (n2) + double *cost, // Output: total transportation cost + uint64_t maxIter // Maximum iterations for solver +); #endif diff --git a/ot/lp/EMD_wrapper.cpp b/ot/lp/EMD_wrapper.cpp index 4aa5a6e72..7b4b9ed6e 100644 --- a/ot/lp/EMD_wrapper.cpp +++ b/ot/lp/EMD_wrapper.cpp @@ -15,8 +15,10 @@ #include "network_simplex_simple.h" #include "network_simplex_simple_omp.h" +#include "sparse_bipartitegraph.h" #include "EMD.h" #include +#include int EMD_wrap(int n1, int n2, double *X, double *Y, double *D, double *G, @@ -216,3 +218,156 @@ int EMD_wrap_omp(int n1, int n2, double *X, double *Y, double *D, double *G, return ret; } + +// ============================================================================ +// SPARSE VERSION: Accepts edge list instead of dense cost matrix +// ============================================================================ +int EMD_wrap_sparse( + int n1, + int n2, + double *X, + double *Y, + uint64_t n_edges, + int64_t *edge_sources, + int64_t *edge_targets, + double *edge_costs, + int64_t *flow_sources_out, + int64_t *flow_targets_out, + double *flow_values_out, + uint64_t *n_flows_out, + double *alpha, + double *beta, + double *cost, + uint64_t maxIter +) { + using namespace lemon; + + uint64_t n = 0; + for (int i = 0; i < n1; i++) { + double val = *(X + i); + if (val > 0) { + n++; + } else if (val < 0) { + return INFEASIBLE; + } + } + + uint64_t m = 0; + for (int i = 0; i < n2; i++) { + double val = *(Y + i); + if (val > 0) { + m++; + } else if (val < 0) { + return INFEASIBLE; + } + } + + std::vector indI(n); // indI[graph_idx] = original_source_idx + std::vector indJ(m); // indJ[graph_idx] = original_target_idx + std::vector weights1(n); // Source masses (positive only) + std::vector weights2(m); // Target masses (negative for demand) + + // Create reverse mapping: original_idx → 
graph_idx + std::vector source_to_graph(n1, -1); + std::vector target_to_graph(n2, -1); + + uint64_t cur = 0; + for (int i = 0; i < n1; i++) { + double val = *(X + i); + if (val > 0) { + weights1[cur] = val; // Store the mass + indI[cur] = i; // Forward map: graph → original + source_to_graph[i] = cur; // Reverse map: original → graph + cur++; + } + } + + cur = 0; + for (int i = 0; i < n2; i++) { + double val = *(Y + i); + if (val > 0) { + weights2[cur] = -val; + indJ[cur] = i; // Forward map: graph → original + target_to_graph[i] = cur; // Reverse map: original → graph + cur++; + } + } + + typedef SparseBipartiteDigraph Digraph; + DIGRAPH_TYPEDEFS(Digraph); + + Digraph di(n, m); + + std::vector> edges; // (source, target) pairs + std::vector edge_to_arc; // edge_to_arc[k] = arc ID for edge k + std::vector arc_costs; // arc_costs[arc_id] = cost (for O(1) lookup) + edges.reserve(n_edges); + edge_to_arc.reserve(n_edges); + + uint64_t valid_edge_count = 0; + for (uint64_t k = 0; k < n_edges; k++) { + int64_t src_orig = edge_sources[k]; + int64_t tgt_orig = edge_targets[k]; + int64_t src = source_to_graph[src_orig]; + int64_t tgt = target_to_graph[tgt_orig]; + + if (src >= 0 && tgt >= 0) { + edges.emplace_back(src, tgt + n); + edge_to_arc.push_back(valid_edge_count); + arc_costs.push_back(edge_costs[k]); // Store cost indexed by arc ID + valid_edge_count++; + } else { + edge_to_arc.push_back(UINT64_MAX); + } + } + + + di.buildFromEdges(edges); + + NetworkSimplexSimple net( + di, true, (int)(n + m), di.arcNum(), maxIter + ); + + net.supplyMap(&weights1[0], (int)n, &weights2[0], (int)m); + + for (uint64_t k = 0; k < n_edges; k++) { + if (edge_to_arc[k] != UINT64_MAX) { + net.setCost(edge_to_arc[k], edge_costs[k]); + } + } + + int ret = net.run(); + + if (ret == (int)net.OPTIMAL || ret == (int)net.MAX_ITER_REACHED) { + *cost = 0; + *n_flows_out = 0; + + Arc a; + di.first(a); + for (; a != INVALID; di.next(a)) { + uint64_t i = di.source(a); + uint64_t j = di.target(a); + 
double flow = net.flow(a); + + uint64_t orig_i = indI[i]; + uint64_t orig_j = indJ[j - n]; + + + double arc_cost = arc_costs[a]; + + *cost += flow * arc_cost; + + + *(alpha + orig_i) = -net.potential(i); + *(beta + orig_j) = net.potential(j); + + if (flow > 1e-15) { + flow_sources_out[*n_flows_out] = orig_i; + flow_targets_out[*n_flows_out] = orig_j; + flow_values_out[*n_flows_out] = flow; + (*n_flows_out)++; + } + } + } + return ret; +} \ No newline at end of file diff --git a/ot/lp/_network_simplex.py b/ot/lp/_network_simplex.py index 492e4c7ac..3ce63a874 100644 --- a/ot/lp/_network_simplex.py +++ b/ot/lp/_network_simplex.py @@ -11,9 +11,11 @@ import numpy as np import warnings +import scipy.sparse as sp +import time from ..utils import list_to_array, check_number_threads from ..backend import get_backend -from .emd_wrap import emd_c, check_result +from .emd_wrap import emd_c, emd_c_sparse, check_result def center_ot_dual(alpha0, beta0, a=None, b=None): @@ -172,6 +174,8 @@ def emd( center_dual=True, numThreads=1, check_marginals=True, + sparse=False, + return_matrix=False, ): r"""Solves the Earth Movers distance problem and returns the OT matrix @@ -232,6 +236,12 @@ def emd( check_marginals: bool, optional (default=True) If True, checks that the marginals mass are equal. If False, skips the check. + sparse: bool, optional (default=False) + If True, uses the sparse solver that only stores edges with finite costs. + When sparse=True, M should be a scipy.sparse matrix or a tuple (row, col, data). + return_matrix: bool, optional (default=False) + If True, returns the transport matrix. If False and sparse=True, returns + sparse flow representation in log. 
Returns @@ -272,38 +282,64 @@ def emd( ot.optim.cg : General regularized OT """ - a, b, M = list_to_array(a, b, M) - nx = get_backend(M, a, b) + edge_sources = None + edge_targets = None + edge_costs = None + n1, n2 = None, None + + if sparse: + if sp.issparse(M): + if not isinstance(M, sp.coo_matrix): + M_coo = sp.coo_matrix(M) + else: + M_coo = M + + edge_sources = M_coo.row if M_coo.row.dtype == np.int64 else M_coo.row.astype(np.int64) + edge_targets = M_coo.col if M_coo.col.dtype == np.int64 else M_coo.col.astype(np.int64) + edge_costs = M_coo.data if M_coo.data.dtype == np.float64 else M_coo.data.astype(np.float64) + n1, n2 = M_coo.shape + elif isinstance(M, tuple) and len(M) == 3: + edge_sources = np.asarray(M[0], dtype=np.int64) + edge_targets = np.asarray(M[1], dtype=np.int64) + edge_costs = np.asarray(M[2], dtype=np.float64) + n1 = int(edge_sources.max() + 1) + n2 = int(edge_targets.max() + 1) + else: + raise ValueError("When sparse=True, M must be a scipy sparse matrix or a tuple (row, col, data)") + + a, b = list_to_array(a, b) + else: + a, b, M = list_to_array(a, b, M) + + nx = get_backend(a, b) if len(a) != 0: type_as = a elif len(b) != 0: type_as = b else: - type_as = M + type_as = a - # if empty array given then use uniform distributions if len(a) == 0: - a = nx.ones((M.shape[0],), type_as=type_as) / M.shape[0] + a = nx.ones((n1,), type_as=type_as) / n1 if n1 else nx.ones((M.shape[0],), type_as=type_as) / M.shape[0] if len(b) == 0: - b = nx.ones((M.shape[1],), type_as=type_as) / M.shape[1] + b = nx.ones((n2,), type_as=type_as) / n2 if n2 else nx.ones((M.shape[1],), type_as=type_as) / M.shape[1] - # convert to numpy - M, a, b = nx.to_numpy(M, a, b) + if sparse: + a, b = nx.to_numpy(a, b) + else: + M, a, b = nx.to_numpy(M, a, b) + M = np.asarray(M, dtype=np.float64, order="C") - # ensure float64 a = np.asarray(a, dtype=np.float64) b = np.asarray(b, dtype=np.float64) - M = np.asarray(M, dtype=np.float64, order="C") - # if empty array given then use 
uniform distributions - if len(a) == 0: - a = np.ones((M.shape[0],), dtype=np.float64) / M.shape[0] - if len(b) == 0: - b = np.ones((M.shape[1],), dtype=np.float64) / M.shape[1] + + if n1 is None: + n1, n2 = M.shape assert ( - a.shape[0] == M.shape[0] and b.shape[0] == M.shape[1] + a.shape[0] == n1 and b.shape[0] == n2 ), "Dimension mismatch, check dimensions of M with a and b" # ensure that same mass @@ -321,13 +357,26 @@ def emd( numThreads = check_number_threads(numThreads) - G, cost, u, v, result_code = emd_c(a, b, M, numItermax, numThreads) + if edge_sources is not None: + flow_sources, flow_targets, flow_values, cost, u, v, result_code = emd_c_sparse( + a, b, edge_sources, edge_targets, edge_costs, numItermax + ) + if return_matrix: + G = np.zeros((len(a), len(b)), dtype=np.float64) + G[flow_sources, flow_targets] = flow_values + else: + G = None + else: + G, cost, u, v, result_code = emd_c(a, b, M, numItermax, numThreads) if center_dual: u, v = center_ot_dual(u, v, a, b) if np.any(~asel) or np.any(~bsel): - u, v = estimate_dual_null_weights(u, v, a, b, M) + if edge_sources is not None: + u, v = center_ot_dual(u, v, a, b) + else: + u, v = estimate_dual_null_weights(u, v, a, b, M) result_code_string = check_result(result_code) if not nx.is_floating_point(type_as): @@ -338,15 +387,29 @@ def emd( "histogram consists of floating point elements.", stacklevel=2, ) + if log: - log = {} - log["cost"] = cost - log["u"] = nx.from_numpy(u, type_as=type_as) - log["v"] = nx.from_numpy(v, type_as=type_as) - log["warning"] = result_code_string - log["result_code"] = result_code - return nx.from_numpy(G, type_as=type_as), log - return nx.from_numpy(G, type_as=type_as) + log_dict = {} + log_dict["cost"] = cost + log_dict["u"] = nx.from_numpy(u, type_as=type_as) + log_dict["v"] = nx.from_numpy(v, type_as=type_as) + log_dict["warning"] = result_code_string + log_dict["result_code"] = result_code + + if edge_sources is not None and not return_matrix: + log_dict["flow_sources"] = 
flow_sources + log_dict["flow_targets"] = flow_targets + log_dict["flow_values"] = flow_values + + if G is not None: + return nx.from_numpy(G, type_as=type_as), log_dict + else: + return None, log_dict + + if G is not None: + return nx.from_numpy(G, type_as=type_as) + else: + raise ValueError("Cannot return matrix when return_matrix=False and sparse=True without log=True") def emd2( @@ -356,10 +419,12 @@ def emd2( processes=1, numItermax=100000, log=False, - return_matrix=False, + center_dual=True, numThreads=1, check_marginals=True, + sparse=False, + return_matrix=False ): r"""Solves the Earth Movers distance problem and returns the loss @@ -420,6 +485,12 @@ def emd2( check_marginals: bool, optional (default=True) If True, checks that the marginals mass are equal. If False, skips the check. + sparse: bool, optional (default=False) + If True, uses the sparse solver that only stores edges with finite costs. + This is memory-efficient when M has many infinite or forbidden edges. + When sparse=True, M should be a scipy.sparse matrix (coo, csr, or csc format) + or a tuple (row_indices, col_indices, costs) representing the edge list. + Edges not included are treated as having infinite cost (forbidden). 
Returns @@ -460,34 +531,78 @@ def emd2( ot.optim.cg : General regularized OT """ - a, b, M = list_to_array(a, b, M) - nx = get_backend(M, a, b) + edge_sources = None + edge_targets = None + edge_costs = None + n1, n2 = None, None + + if sparse: + if sp.issparse(M): + t0 = time.perf_counter() + if not isinstance(M, sp.coo_matrix): + M_coo = sp.coo_matrix(M) + else: + M_coo = M + t1 = time.perf_counter() + + edge_sources = M_coo.row if M_coo.row.dtype == np.int64 else M_coo.row.astype(np.int64) + edge_targets = M_coo.col if M_coo.col.dtype == np.int64 else M_coo.col.astype(np.int64) + edge_costs = M_coo.data if M_coo.data.dtype == np.float64 else M_coo.data.astype(np.float64) + t2 = time.perf_counter() + print(f"[PY SPARSE] COO conversion: {(t1-t0)*1000:.3f} ms, array copies: {(t2-t1)*1000:.3f} ms") + n1, n2 = M_coo.shape + elif isinstance(M, tuple) and len(M) == 3: + edge_sources = np.asarray(M[0], dtype=np.int64) + edge_targets = np.asarray(M[1], dtype=np.int64) + edge_costs = np.asarray(M[2], dtype=np.float64) + n1 = int(edge_sources.max() + 1) + n2 = int(edge_targets.max() + 1) + else: + raise ValueError( + "When sparse=True, M must be a scipy sparse matrix or a tuple (row, col, data)" + ) + + a, b = list_to_array(a, b) + else: + a, b, M = list_to_array(a, b, M) + + nx = get_backend(a, b) if len(a) != 0: type_as = a elif len(b) != 0: type_as = b else: - type_as = M + type_as = a # Can't use M for sparse case # if empty array given then use uniform distributions if len(a) == 0: - a = nx.ones((M.shape[0],), type_as=type_as) / M.shape[0] + a = nx.ones((n1,), type_as=type_as) / n1 if n1 else nx.ones((M.shape[0],), type_as=type_as) / M.shape[0] if len(b) == 0: - b = nx.ones((M.shape[1],), type_as=type_as) / M.shape[1] + b = nx.ones((n2,), type_as=type_as) / n2 if n2 else nx.ones((M.shape[1],), type_as=type_as) / M.shape[1] + + a0, b0 = a, b + M0 = None if sparse else M - # store original tensors - a0, b0, M0 = a, b, M + if sparse: + edge_costs_original = 
nx.from_numpy(edge_costs, type_as=type_as) + else: + edge_costs_original = None - # convert to numpy - M, a, b = nx.to_numpy(M, a, b) + if sparse: + a, b = nx.to_numpy(a, b) + else: + M, a, b = nx.to_numpy(M, a, b) + M = np.asarray(M, dtype=np.float64, order="C") a = np.asarray(a, dtype=np.float64) b = np.asarray(b, dtype=np.float64) - M = np.asarray(M, dtype=np.float64, order="C") + + if n1 is None: + n1, n2 = M.shape assert ( - a.shape[0] == M.shape[0] and b.shape[0] == M.shape[1] + a.shape[0] == n1 and b.shape[0] == n2 ), "Dimension mismatch, check dimensions of M with a and b" # ensure that same mass @@ -509,13 +624,36 @@ def emd2( def f(b): bsel = b != 0 - G, cost, u, v, result_code = emd_c(a, b, M, numItermax, numThreads) + if edge_sources is not None: + flow_sources, flow_targets, flow_values, cost, u, v, result_code = emd_c_sparse( + a, b, edge_sources, edge_targets, edge_costs, numItermax + ) + + edge_to_idx = {(edge_sources[k], edge_targets[k]): k for k in range(len(edge_sources))} + + grad_edge_costs = np.zeros(len(edge_costs), dtype=np.float64) + for idx in range(len(flow_sources)): + src, tgt, flow = flow_sources[idx], flow_targets[idx], flow_values[idx] + edge_idx = edge_to_idx.get((src, tgt), -1) + if edge_idx >= 0: + grad_edge_costs[edge_idx] = flow + + if return_matrix: + G = np.zeros((len(a), len(b)), dtype=np.float64) + G[flow_sources, flow_targets] = flow_values + else: + G = None + else: + G, cost, u, v, result_code = emd_c(a, b, M, numItermax, numThreads) if center_dual: u, v = center_ot_dual(u, v, a, b) if np.any(~asel) or np.any(~bsel): - u, v = estimate_dual_null_weights(u, v, a, b, M) + if edge_sources is not None: + u, v = center_ot_dual(u, v, a, b) + else: + u, v = estimate_dual_null_weights(u, v, a, b, M) result_code_string = check_result(result_code) log = {} @@ -527,30 +665,59 @@ def f(b): "histogram consists of floating point elements.", stacklevel=2, ) - G = nx.from_numpy(G, type_as=type_as) - if return_matrix: - log["G"] = G + + if 
G is not None: + G = nx.from_numpy(G, type_as=type_as) + if return_matrix: + log["G"] = G log["u"] = nx.from_numpy(u, type_as=type_as) log["v"] = nx.from_numpy(v, type_as=type_as) log["warning"] = result_code_string log["result_code"] = result_code - cost = nx.set_gradients( - nx.from_numpy(cost, type_as=type_as), - (a0, b0, M0), - (log["u"] - nx.mean(log["u"]), log["v"] - nx.mean(log["v"]), G), - ) + + if edge_sources is not None: + cost = nx.set_gradients( + nx.from_numpy(cost, type_as=type_as), + (a0, b0, edge_costs_original), + (log["u"] - nx.mean(log["u"]), log["v"] - nx.mean(log["v"]), nx.from_numpy(grad_edge_costs, type_as=type_as)), + ) + else: + cost = nx.set_gradients( + nx.from_numpy(cost, type_as=type_as), + (a0, b0, M0), + (log["u"] - nx.mean(log["u"]), log["v"] - nx.mean(log["v"]), G), + ) return [cost, log] else: def f(b): bsel = b != 0 - G, cost, u, v, result_code = emd_c(a, b, M, numItermax, numThreads) + + if edge_sources is not None: + flow_sources, flow_targets, flow_values, cost, u, v, result_code = emd_c_sparse( + a, b, edge_sources, edge_targets, edge_costs, numItermax + ) + + edge_to_idx = {(edge_sources[k], edge_targets[k]): k for k in range(len(edge_sources))} + grad_edge_costs = np.zeros(len(edge_costs), dtype=np.float64) + for idx in range(len(flow_sources)): + src, tgt, flow = flow_sources[idx], flow_targets[idx], flow_values[idx] + edge_idx = edge_to_idx.get((src, tgt), -1) + if edge_idx >= 0: + grad_edge_costs[edge_idx] = flow + + G = None + else: + G, cost, u, v, result_code = emd_c(a, b, M, numItermax, numThreads) if center_dual: u, v = center_ot_dual(u, v, a, b) if np.any(~asel) or np.any(~bsel): - u, v = estimate_dual_null_weights(u, v, a, b, M) + if edge_sources is not None: + u, v = center_ot_dual(u, v, a, b) + else: + u, v = estimate_dual_null_weights(u, v, a, b, M) if not nx.is_floating_point(type_as): warnings.warn( @@ -560,16 +727,29 @@ def f(b): "histogram consists of floating point elements.", stacklevel=2, ) - G = 
nx.from_numpy(G, type_as=type_as) - cost = nx.set_gradients( - nx.from_numpy(cost, type_as=type_as), - (a0, b0, M0), - ( - nx.from_numpy(u - np.mean(u), type_as=type_as), - nx.from_numpy(v - np.mean(v), type_as=type_as), - G, - ), - ) + + if edge_sources is not None: + # Sparse: gradient w.r.t. edge_costs (no need to convert G) + cost = nx.set_gradients( + nx.from_numpy(cost, type_as=type_as), + (a0, b0, edge_costs_original), + ( + nx.from_numpy(u - np.mean(u), type_as=type_as), + nx.from_numpy(v - np.mean(v), type_as=type_as), + nx.from_numpy(grad_edge_costs, type_as=type_as), + ), + ) + else: + G = nx.from_numpy(G, type_as=type_as) + cost = nx.set_gradients( + nx.from_numpy(cost, type_as=type_as), + (a0, b0, M0), + ( + nx.from_numpy(u - np.mean(u), type_as=type_as), + nx.from_numpy(v - np.mean(v), type_as=type_as), + G, + ), + ) check_result(result_code) return cost diff --git a/ot/lp/emd_wrap.pyx b/ot/lp/emd_wrap.pyx index 53df54fc3..b4f603605 100644 --- a/ot/lp/emd_wrap.pyx +++ b/ot/lp/emd_wrap.pyx @@ -22,6 +22,7 @@ import warnings cdef extern from "EMD.h": int EMD_wrap(int n1,int n2, double *X, double *Y,double *D, double *G, double* alpha, double* beta, double *cost, uint64_t maxIter) nogil int EMD_wrap_omp(int n1,int n2, double *X, double *Y,double *D, double *G, double* alpha, double* beta, double *cost, uint64_t maxIter, int numThreads) nogil + int EMD_wrap_sparse(int n1, int n2, double *X, double *Y, uint64_t n_edges, long long *edge_sources, long long *edge_targets, double *edge_costs, long long *flow_sources_out, long long *flow_targets_out, double *flow_values_out, uint64_t *n_flows_out, double *alpha, double *beta, double *cost, uint64_t maxIter) nogil cdef enum ProblemType: INFEASIBLE, OPTIMAL, UNBOUNDED, MAX_ITER_REACHED @@ -206,3 +207,78 @@ def emd_1d_sorted(np.ndarray[double, ndim=1, mode="c"] u_weights, cur_idx += 1 cur_idx += 1 return G[:cur_idx], indices[:cur_idx], cost + +@cython.boundscheck(False) +@cython.wraparound(False) +def 
emd_c_sparse(np.ndarray[double, ndim=1, mode="c"] a, + np.ndarray[double, ndim=1, mode="c"] b, + np.ndarray[long long, ndim=1, mode="c"] edge_sources, + np.ndarray[long long, ndim=1, mode="c"] edge_targets, + np.ndarray[double, ndim=1, mode="c"] edge_costs, + uint64_t max_iter): + """ + Sparse EMD solver - only considers edges in edge_sources/edge_targets + + Parameters + ---------- + a : (n1,) array + Source histogram + b : (n2,) array + Target histogram + edge_sources : (k,) array, int64 + Source indices for each edge + edge_targets : (k,) array, int64 + Target indices for each edge + edge_costs : (k,) array, float64 + Cost for each edge + max_iter : uint64_t + Maximum iterations + + Returns + ------- + flow_sources : (n_flows,) array, int64 + Source indices of non-zero flows + flow_targets : (n_flows,) array, int64 + Target indices of non-zero flows + flow_values : (n_flows,) array, float64 + Flow values + cost : float + Total cost + alpha : (n1,) array + Dual variables for sources + beta : (n2,) array + Dual variables for targets + result_code : int + Result status + """ + cdef int n1 = a.shape[0] + cdef int n2 = b.shape[0] + cdef uint64_t n_edges = edge_sources.shape[0] + cdef uint64_t n_flows_out = 0 + cdef int result_code = 0 + cdef double cost = 0 + + # Allocate output arrays (max size = n_edges) + cdef np.ndarray[long long, ndim=1, mode="c"] flow_sources = np.zeros(n_edges, dtype=np.int64) + cdef np.ndarray[long long, ndim=1, mode="c"] flow_targets = np.zeros(n_edges, dtype=np.int64) + cdef np.ndarray[double, ndim=1, mode="c"] flow_values = np.zeros(n_edges, dtype=np.float64) + cdef np.ndarray[double, ndim=1, mode="c"] alpha = np.zeros(n1) + cdef np.ndarray[double, ndim=1, mode="c"] beta = np.zeros(n2) + + with nogil: + result_code = EMD_wrap_sparse( + n1, n2, + a.data, b.data, + n_edges, + edge_sources.data, edge_targets.data, edge_costs.data, + flow_sources.data, flow_targets.data, flow_values.data, + &n_flows_out, + alpha.data, beta.data, &cost, 
max_iter + ) + + # Trim to actual number of flows + flow_sources = flow_sources[:n_flows_out] + flow_targets = flow_targets[:n_flows_out] + flow_values = flow_values[:n_flows_out] + + return flow_sources, flow_targets, flow_values, cost, alpha, beta, result_code \ No newline at end of file diff --git a/ot/lp/sparse_bipartitegraph.h b/ot/lp/sparse_bipartitegraph.h new file mode 100644 index 000000000..7ba13b41a --- /dev/null +++ b/ot/lp/sparse_bipartitegraph.h @@ -0,0 +1,281 @@ +/* -*- mode: C++; indent-tabs-mode: nil; -*- + * + * Sparse bipartite graph for optimal transport + * Only stores edges that are explicitly added (not all n1×n2 edges) + * + * Uses CSR (Compressed Sparse Row) format for better cache locality and performance + * - Binary search for arc lookup: O(log k) where k = avg edges per node + * - Compact memory layout for better cache performance + * - Edges may be provided in any order; buildFromEdges() sorts them by (source, target) + */ + +#pragma once + +#include "core.h" +#include +#include +#include + +namespace lemon { + + class SparseBipartiteDigraphBase { + public: + + typedef SparseBipartiteDigraphBase Digraph; + typedef int Node; + typedef int64_t Arc; + + protected: + + int _node_num; + int64_t _arc_num; + int _n1, _n2; + + std::vector _arc_sources; // _arc_sources[arc_id] = source node + std::vector _arc_targets; // _arc_targets[arc_id] = target node + + // CSR format + // _row_ptr[i] = start index in _col_indices for source node i + // _row_ptr[i+1] - _row_ptr[i] = number of outgoing edges from node i + std::vector _row_ptr; + std::vector _col_indices; + std::vector _arc_ids; + + mutable std::vector> _in_arcs; // _in_arcs[node] = incoming arc IDs + mutable bool _in_arcs_built; + + SparseBipartiteDigraphBase() : _node_num(0), _arc_num(0), _n1(0), _n2(0), _in_arcs_built(false) {} + + void construct(int n1, int n2) { + _node_num = n1 + n2; + _n1 = n1; + _n2 = n2; + _arc_num = 0; + _arc_sources.clear(); + _arc_targets.clear(); + _row_ptr.clear(); + 
_col_indices.clear(); + _arc_ids.clear(); + _in_arcs.clear(); + _in_arcs_built = false; + } + + void build_in_arcs() const { + if (_in_arcs_built) return; + + _in_arcs.resize(_node_num); + + for (Arc a = 0; a < _arc_num; ++a) { + Node tgt = _arc_targets[a]; + _in_arcs[tgt].push_back(a); + } + + _in_arcs_built = true; + } + + public: + + Node operator()(int ix) const { return Node(ix); } + static int index(const Node& node) { return node; } + + void buildFromEdges(const std::vector>& edges) { + _arc_num = edges.size(); + + if (_arc_num == 0) { + _row_ptr.assign(_n1 + 1, 0); + return; + } + + // Create indexed edges: (source, target, original_arc_id) + std::vector> indexed_edges; + indexed_edges.reserve(_arc_num); + for (Arc i = 0; i < _arc_num; ++i) { + indexed_edges.emplace_back(edges[i].first, edges[i].second, i); + } + + // Sort by source node, then by target node CSR requirement + std::sort(indexed_edges.begin(), indexed_edges.end(), + [](const auto& a, const auto& b) { + if (std::get<0>(a) != std::get<0>(b)) + return std::get<0>(a) < std::get<0>(b); + return std::get<1>(a) < std::get<1>(b); + }); + + _arc_sources.resize(_arc_num); + _arc_targets.resize(_arc_num); + _col_indices.resize(_arc_num); + _arc_ids.resize(_arc_num); + _row_ptr.resize(_n1 + 1); + + _row_ptr[0] = 0; + int current_row = 0; + + for (int64_t i = 0; i < _arc_num; ++i) { + Node src = std::get<0>(indexed_edges[i]); + Node tgt = std::get<1>(indexed_edges[i]); + Arc orig_arc_id = std::get<2>(indexed_edges[i]); + + // Fill out row_ptr for rows with no outgoing edges + while (current_row < src) { + _row_ptr[++current_row] = i; + } + + _arc_sources[orig_arc_id] = src; + _arc_targets[orig_arc_id] = tgt; + _col_indices[i] = tgt; + _arc_ids[i] = orig_arc_id; + } + + // Fill remaining row_ptr entries + while (current_row < _n1) { + _row_ptr[++current_row] = _arc_num; + } + + _in_arcs_built = false; + } + + // Find arc from s to t using binary search (returns -1 if not found) + Arc arc(const Node& s, 
const Node& t) const { + if (s < 0 || s >= _n1 || t < _n1 || t >= _node_num) { + return Arc(-1); + } + + int64_t start = _row_ptr[s]; + int64_t end = _row_ptr[s + 1]; + + // Binary search for target t in col_indices[start:end] + auto it = std::lower_bound( + _col_indices.begin() + start, + _col_indices.begin() + end, + t + ); + + if (it != _col_indices.begin() + end && *it == t) { + int64_t pos = it - _col_indices.begin(); + return _arc_ids[pos]; + } + + return Arc(-1); + } + + int nodeNum() const { return _node_num; } + int64_t arcNum() const { return _arc_num; } + + int maxNodeId() const { return _node_num - 1; } + int64_t maxArcId() const { return _arc_num - 1; } + + Node source(Arc arc) const { + return (arc >= 0 && arc < _arc_num) ? _arc_sources[arc] : Node(-1); + } + + Node target(Arc arc) const { + return (arc >= 0 && arc < _arc_num) ? _arc_targets[arc] : Node(-1); + } + + static int id(Node node) { return node; } + static int64_t id(Arc arc) { return arc; } + + static Node nodeFromId(int id) { return Node(id); } + static Arc arcFromId(int64_t id) { return Arc(id); } + + Arc findArc(Node s, Node t, Arc prev = -1) const { + return prev == -1 ? arc(s, t) : Arc(-1); + } + + void first(Node& node) const { + node = _node_num - 1; + } + + static void next(Node& node) { + --node; + } + + void first(Arc& arc) const { + arc = _arc_num - 1; + } + + static void next(Arc& arc) { + --arc; + } + + void firstOut(Arc& arc, const Node& node) const { + if (node < 0 || node >= _n1) { + arc = -1; + return; + } + + int64_t start = _row_ptr[node]; + int64_t end = _row_ptr[node + 1]; + + arc = (start < end) ? _arc_ids[start] : Arc(-1); + } + + void nextOut(Arc& arc) const { + if (arc < 0) return; + + Node src = _arc_sources[arc]; + int64_t start = _row_ptr[src]; + int64_t end = _row_ptr[src + 1]; + + for (int64_t i = start; i < end; ++i) { + if (_arc_ids[i] == arc) { + arc = (i + 1 < end) ? 
_arc_ids[i + 1] : Arc(-1); + return; + } + } + arc = -1; + } + + void firstIn(Arc& arc, const Node& node) const { + build_in_arcs(); // Lazy build on first call + + if (node < 0 || node >= _node_num || node < _n1) { + arc = -1; // Invalid node or source nodes have no incoming arcs + return; + } + + const std::vector& in = _in_arcs[node]; + arc = in.empty() ? Arc(-1) : in[0]; + } + + void nextIn(Arc& arc) const { + if (arc < 0) return; + + Node tgt = _arc_targets[arc]; + const std::vector& in = _in_arcs[tgt]; + + // Find current arc in the list and return next one + for (size_t i = 0; i < in.size(); ++i) { + if (in[i] == arc) { + arc = (i + 1 < in.size()) ? in[i + 1] : Arc(-1); + return; + } + } + arc = -1; + } + }; + + /// Sparse bipartite digraph - only stores edges that are explicitly added + class SparseBipartiteDigraph : public SparseBipartiteDigraphBase { + typedef SparseBipartiteDigraphBase Parent; + + public: + + SparseBipartiteDigraph() { construct(0, 0); } + + SparseBipartiteDigraph(int n1, int n2) { construct(n1, n2); } + + Node operator()(int ix) const { return Parent::operator()(ix); } + static int index(const Node& node) { return Parent::index(node); } + + void buildFromEdges(const std::vector>& edges) { + Parent::buildFromEdges(edges); + } + + Arc arc(Node s, Node t) const { return Parent::arc(s, t); } + + int nodeNum() const { return Parent::nodeNum(); } + int64_t arcNum() const { return Parent::arcNum(); } + }; + +} //namespace lemon diff --git a/setup.py b/setup.py index acbe5aed9..c8cefb729 100644 --- a/setup.py +++ b/setup.py @@ -50,7 +50,19 @@ link_args += flags if sys.platform.startswith("darwin"): - compile_args.append("-stdlib=libc++") + # Only add -stdlib=libc++ for Clang, not GCC + # GCC uses libstdc++ by default and doesn't recognize -stdlib flag + import subprocess + try: + # Check if using clang + compiler = os.environ.get('CXX', 'c++') + version_output = subprocess.check_output([compiler, '--version'], stderr=subprocess.STDOUT).decode() 
+ if 'clang' in version_output.lower(): + compile_args.append("-stdlib=libc++") + except Exception: + # If we can't determine, don't add the flag (safer for GCC) + pass + sdk_path = subprocess.check_output(["xcrun", "--show-sdk-path"]) os.environ["CFLAGS"] = '-isysroot "{}"'.format(sdk_path.rstrip().decode("utf-8")) diff --git a/test/test_ot.py b/test/test_ot.py index e8217d54d..e4c55f6f4 100644 --- a/test/test_ot.py +++ b/test/test_ot.py @@ -12,7 +12,7 @@ import ot from ot.datasets import make_1D_gauss as gauss from ot.backend import torch, tf, get_backend - +from scipy.sparse import coo_matrix def test_emd_dimension_and_mass_mismatch(): # test emd and emd2 for dimension mismatch @@ -914,6 +914,153 @@ def test_dual_variables(): assert constraint_violation.max() < 1e-8 +def test_emd_sparse_vs_dense(): + + n_source = 100 + n_target = 100 + k = 10 + + rng = np.random.RandomState(42) + + x_source = rng.randn(n_source, 2) + x_target = rng.randn(n_target, 2) + 0.5 + + a = ot.utils.unif(n_source) + b = ot.utils.unif(n_target) + + C = ot.dist(x_source, x_target) + + rows = [] + cols = [] + data = [] + + for i in range(n_source): + distances = C[i, :] + nearest_k = np.argpartition(distances, k)[:k] + for j in nearest_k: + rows.append(i) + cols.append(j) + data.append(C[i, j]) + + C_knn = coo_matrix((data, (rows, cols)), shape=(n_source, n_target)) + + large_cost = 1e8 + C_dense_infty = np.full((n_source, n_target), large_cost) + C_knn_array = C_knn.toarray() + C_dense_infty[C_knn_array > 0] = C_knn_array[C_knn_array > 0] + + G_dense_initial = ot.emd(a, b, C_dense_infty) + eps = 1e-9 + active_mask = G_dense_initial > eps + knn_mask = C_knn_array > 0 + extra_edges_mask = active_mask & ~knn_mask + + rows_aug = [] + cols_aug = [] + data_aug = [] + + knn_rows, knn_cols = np.where(knn_mask) + for i, j in zip(knn_rows, knn_cols): + rows_aug.append(i) + cols_aug.append(j) + data_aug.append(C[i, j]) + + extra_rows, extra_cols = np.where(extra_edges_mask) + for i, j in 
zip(extra_rows, extra_cols): + rows_aug.append(i) + cols_aug.append(j) + data_aug.append(C[i, j]) + + C_augmented = coo_matrix((data_aug, (rows_aug, cols_aug)), shape=(n_source, n_target)) + + C_augmented_dense = np.full((n_source, n_target), large_cost) + C_augmented_array = C_augmented.toarray() + C_augmented_dense[C_augmented_array > 0] = C_augmented_array[C_augmented_array > 0] + + G_dense, log_dense = ot.emd(a, b, C_augmented_dense, log=True) + G_sparse, log_sparse = ot.emd(a, b, C_augmented, log=True, sparse=True, return_matrix=True) + + cost_dense = log_dense['cost'] + cost_sparse = log_sparse['cost'] + + np.testing.assert_allclose(cost_dense, cost_sparse, rtol=1e-5, atol=1e-7) + + np.testing.assert_allclose(a, G_dense.sum(1), rtol=1e-5, atol=1e-7) + np.testing.assert_allclose(b, G_dense.sum(0), rtol=1e-5, atol=1e-7) + np.testing.assert_allclose(a, G_sparse.sum(1), rtol=1e-5, atol=1e-7) + np.testing.assert_allclose(b, G_sparse.sum(0), rtol=1e-5, atol=1e-7) + + +def test_emd2_sparse_vs_dense(): + + n_source = 100 + n_target = 100 + k = 10 + + rng = np.random.RandomState(42) + + x_source = rng.randn(n_source, 2) + x_target = rng.randn(n_target, 2) + 0.5 + + a = ot.utils.unif(n_source) + b = ot.utils.unif(n_target) + + C = ot.dist(x_source, x_target) + + rows = [] + cols = [] + data = [] + + for i in range(n_source): + distances = C[i, :] + nearest_k = np.argpartition(distances, k)[:k] + for j in nearest_k: + rows.append(i) + cols.append(j) + data.append(C[i, j]) + + C_knn = coo_matrix((data, (rows, cols)), shape=(n_source, n_target)) + + large_cost = 1e8 + C_dense_infty = np.full((n_source, n_target), large_cost) + C_knn_array = C_knn.toarray() + C_dense_infty[C_knn_array > 0] = C_knn_array[C_knn_array > 0] + + G_dense_initial = ot.emd(a, b, C_dense_infty) + + eps = 1e-9 + active_mask = G_dense_initial > eps + knn_mask = C_knn_array > 0 + extra_edges_mask = active_mask & ~knn_mask + + rows_aug = [] + cols_aug = [] + data_aug = [] + + knn_rows, knn_cols = 
np.where(knn_mask) + for i, j in zip(knn_rows, knn_cols): + rows_aug.append(i) + cols_aug.append(j) + data_aug.append(C[i, j]) + + extra_rows, extra_cols = np.where(extra_edges_mask) + for i, j in zip(extra_rows, extra_cols): + rows_aug.append(i) + cols_aug.append(j) + data_aug.append(C[i, j]) + + C_augmented = coo_matrix((data_aug, (rows_aug, cols_aug)), shape=(n_source, n_target)) + + C_augmented_dense = np.full((n_source, n_target), large_cost) + C_augmented_array = C_augmented.toarray() + C_augmented_dense[C_augmented_array > 0] = C_augmented_array[C_augmented_array > 0] + + cost_dense = ot.emd2(a, b, C_augmented_dense) + cost_sparse = ot.emd2(a, b, C_augmented, sparse=True) + + np.testing.assert_allclose(cost_dense, cost_sparse, rtol=1e-5, atol=1e-7) + + def check_duality_gap(a, b, M, G, u, v, cost): cost_dual = np.vdot(a, u) + np.vdot(b, v) # Check that dual and primal cost are equal From 0eee6f1e14a4c78039f453158bce08a42f4f9098 Mon Sep 17 00:00:00 2001 From: nathanneike Date: Tue, 28 Oct 2025 16:22:31 +0100 Subject: [PATCH 2/5] [WIP] Add sparse EMD solver with unit tests This PR implements a sparse bipartite graph EMD solver for memory-efficient optimal transport when the cost matrix has many infinite or forbidden edges. 
Changes: - Implement sparse bipartite graph EMD solver in C++ - Add Python bindings for sparse solver (emd_wrap.pyx, _network_simplex.py) - Add unit tests to verify sparse and dense solvers produce identical results - Tests use augmented k-NN approach to ensure fair comparison Tests verify correctness: * test_emd_sparse_vs_dense() - verifies identical costs and marginal constraints * test_emd2_sparse_vs_dense() - verifies cost-only version Status: WIP - seeking feedback on implementation approach TODO: Add example script and documentation --- ot/lp/_network_simplex.py | 112 +++++++++++++++++++++++++++++--------- test/test_ot.py | 29 +++++++--- 2 files changed, 107 insertions(+), 34 deletions(-) diff --git a/ot/lp/_network_simplex.py b/ot/lp/_network_simplex.py index 3ce63a874..35b185746 100644 --- a/ot/lp/_network_simplex.py +++ b/ot/lp/_network_simplex.py @@ -294,9 +294,17 @@ def emd( else: M_coo = M - edge_sources = M_coo.row if M_coo.row.dtype == np.int64 else M_coo.row.astype(np.int64) - edge_targets = M_coo.col if M_coo.col.dtype == np.int64 else M_coo.col.astype(np.int64) - edge_costs = M_coo.data if M_coo.data.dtype == np.float64 else M_coo.data.astype(np.float64) + edge_sources = ( + M_coo.row if M_coo.row.dtype == np.int64 else M_coo.row.astype(np.int64) + ) + edge_targets = ( + M_coo.col if M_coo.col.dtype == np.int64 else M_coo.col.astype(np.int64) + ) + edge_costs = ( + M_coo.data + if M_coo.data.dtype == np.float64 + else M_coo.data.astype(np.float64) + ) n1, n2 = M_coo.shape elif isinstance(M, tuple) and len(M) == 3: edge_sources = np.asarray(M[0], dtype=np.int64) @@ -305,7 +313,9 @@ def emd( n1 = int(edge_sources.max() + 1) n2 = int(edge_targets.max() + 1) else: - raise ValueError("When sparse=True, M must be a scipy sparse matrix or a tuple (row, col, data)") + raise ValueError( + "When sparse=True, M must be a scipy sparse matrix or a tuple (row, col, data)" + ) a, b = list_to_array(a, b) else: @@ -321,9 +331,17 @@ def emd( type_as = a if len(a) == 
0: - a = nx.ones((n1,), type_as=type_as) / n1 if n1 else nx.ones((M.shape[0],), type_as=type_as) / M.shape[0] + a = ( + nx.ones((n1,), type_as=type_as) / n1 + if n1 + else nx.ones((M.shape[0],), type_as=type_as) / M.shape[0] + ) if len(b) == 0: - b = nx.ones((n2,), type_as=type_as) / n2 if n2 else nx.ones((M.shape[1],), type_as=type_as) / M.shape[1] + b = ( + nx.ones((n2,), type_as=type_as) / n2 + if n2 + else nx.ones((M.shape[1],), type_as=type_as) / M.shape[1] + ) if sparse: a, b = nx.to_numpy(a, b) @@ -334,7 +352,6 @@ def emd( a = np.asarray(a, dtype=np.float64) b = np.asarray(b, dtype=np.float64) - if n1 is None: n1, n2 = M.shape @@ -409,7 +426,9 @@ def emd( if G is not None: return nx.from_numpy(G, type_as=type_as) else: - raise ValueError("Cannot return matrix when return_matrix=False and sparse=True without log=True") + raise ValueError( + "Cannot return matrix when return_matrix=False and sparse=True without log=True" + ) def emd2( @@ -419,12 +438,11 @@ def emd2( processes=1, numItermax=100000, log=False, - center_dual=True, numThreads=1, check_marginals=True, sparse=False, - return_matrix=False + return_matrix=False, ): r"""Solves the Earth Movers distance problem and returns the loss @@ -534,7 +552,7 @@ def emd2( edge_sources = None edge_targets = None edge_costs = None - n1, n2 = None, None + n1, n2 = None, None if sparse: if sp.issparse(M): @@ -545,11 +563,21 @@ def emd2( M_coo = M t1 = time.perf_counter() - edge_sources = M_coo.row if M_coo.row.dtype == np.int64 else M_coo.row.astype(np.int64) - edge_targets = M_coo.col if M_coo.col.dtype == np.int64 else M_coo.col.astype(np.int64) - edge_costs = M_coo.data if M_coo.data.dtype == np.float64 else M_coo.data.astype(np.float64) + edge_sources = ( + M_coo.row if M_coo.row.dtype == np.int64 else M_coo.row.astype(np.int64) + ) + edge_targets = ( + M_coo.col if M_coo.col.dtype == np.int64 else M_coo.col.astype(np.int64) + ) + edge_costs = ( + M_coo.data + if M_coo.data.dtype == np.float64 + else 
M_coo.data.astype(np.float64) + ) t2 = time.perf_counter() - print(f"[PY SPARSE] COO conversion: {(t1-t0)*1000:.3f} ms, array copies: {(t2-t1)*1000:.3f} ms") + print( + f"[PY SPARSE] COO conversion: {(t1-t0)*1000:.3f} ms, array copies: {(t2-t1)*1000:.3f} ms" + ) n1, n2 = M_coo.shape elif isinstance(M, tuple) and len(M) == 3: edge_sources = np.asarray(M[0], dtype=np.int64) @@ -577,12 +605,20 @@ def emd2( # if empty array given then use uniform distributions if len(a) == 0: - a = nx.ones((n1,), type_as=type_as) / n1 if n1 else nx.ones((M.shape[0],), type_as=type_as) / M.shape[0] + a = ( + nx.ones((n1,), type_as=type_as) / n1 + if n1 + else nx.ones((M.shape[0],), type_as=type_as) / M.shape[0] + ) if len(b) == 0: - b = nx.ones((n2,), type_as=type_as) / n2 if n2 else nx.ones((M.shape[1],), type_as=type_as) / M.shape[1] + b = ( + nx.ones((n2,), type_as=type_as) / n2 + if n2 + else nx.ones((M.shape[1],), type_as=type_as) / M.shape[1] + ) a0, b0 = a, b - M0 = None if sparse else M + M0 = None if sparse else M if sparse: edge_costs_original = nx.from_numpy(edge_costs, type_as=type_as) @@ -625,15 +661,24 @@ def f(b): bsel = b != 0 if edge_sources is not None: - flow_sources, flow_targets, flow_values, cost, u, v, result_code = emd_c_sparse( - a, b, edge_sources, edge_targets, edge_costs, numItermax + flow_sources, flow_targets, flow_values, cost, u, v, result_code = ( + emd_c_sparse( + a, b, edge_sources, edge_targets, edge_costs, numItermax + ) ) - edge_to_idx = {(edge_sources[k], edge_targets[k]): k for k in range(len(edge_sources))} + edge_to_idx = { + (edge_sources[k], edge_targets[k]): k + for k in range(len(edge_sources)) + } grad_edge_costs = np.zeros(len(edge_costs), dtype=np.float64) for idx in range(len(flow_sources)): - src, tgt, flow = flow_sources[idx], flow_targets[idx], flow_values[idx] + src, tgt, flow = ( + flow_sources[idx], + flow_targets[idx], + flow_values[idx], + ) edge_idx = edge_to_idx.get((src, tgt), -1) if edge_idx >= 0: grad_edge_costs[edge_idx] = 
flow @@ -679,7 +724,11 @@ def f(b): cost = nx.set_gradients( nx.from_numpy(cost, type_as=type_as), (a0, b0, edge_costs_original), - (log["u"] - nx.mean(log["u"]), log["v"] - nx.mean(log["v"]), nx.from_numpy(grad_edge_costs, type_as=type_as)), + ( + log["u"] - nx.mean(log["u"]), + log["v"] - nx.mean(log["v"]), + nx.from_numpy(grad_edge_costs, type_as=type_as), + ), ) else: cost = nx.set_gradients( @@ -694,14 +743,23 @@ def f(b): bsel = b != 0 if edge_sources is not None: - flow_sources, flow_targets, flow_values, cost, u, v, result_code = emd_c_sparse( - a, b, edge_sources, edge_targets, edge_costs, numItermax + flow_sources, flow_targets, flow_values, cost, u, v, result_code = ( + emd_c_sparse( + a, b, edge_sources, edge_targets, edge_costs, numItermax + ) ) - edge_to_idx = {(edge_sources[k], edge_targets[k]): k for k in range(len(edge_sources))} + edge_to_idx = { + (edge_sources[k], edge_targets[k]): k + for k in range(len(edge_sources)) + } grad_edge_costs = np.zeros(len(edge_costs), dtype=np.float64) for idx in range(len(flow_sources)): - src, tgt, flow = flow_sources[idx], flow_targets[idx], flow_values[idx] + src, tgt, flow = ( + flow_sources[idx], + flow_targets[idx], + flow_values[idx], + ) edge_idx = edge_to_idx.get((src, tgt), -1) if edge_idx >= 0: grad_edge_costs[edge_idx] = flow diff --git a/test/test_ot.py b/test/test_ot.py index e4c55f6f4..ec12f63c8 100644 --- a/test/test_ot.py +++ b/test/test_ot.py @@ -14,6 +14,7 @@ from ot.backend import torch, tf, get_backend from scipy.sparse import coo_matrix + def test_emd_dimension_and_mass_mismatch(): # test emd and emd2 for dimension mismatch n_samples = 100 @@ -915,10 +916,14 @@ def test_dual_variables(): def test_emd_sparse_vs_dense(): + """Test that sparse and dense EMD solvers produce identical results. + Uses augmented k-NN graph approach: first solves with dense solver to + identify needed edges, then compares both solvers on the same graph. 
+ """ n_source = 100 n_target = 100 - k = 10 + k = 10 rng = np.random.RandomState(42) @@ -971,17 +976,21 @@ def test_emd_sparse_vs_dense(): cols_aug.append(j) data_aug.append(C[i, j]) - C_augmented = coo_matrix((data_aug, (rows_aug, cols_aug)), shape=(n_source, n_target)) + C_augmented = coo_matrix( + (data_aug, (rows_aug, cols_aug)), shape=(n_source, n_target) + ) C_augmented_dense = np.full((n_source, n_target), large_cost) C_augmented_array = C_augmented.toarray() C_augmented_dense[C_augmented_array > 0] = C_augmented_array[C_augmented_array > 0] G_dense, log_dense = ot.emd(a, b, C_augmented_dense, log=True) - G_sparse, log_sparse = ot.emd(a, b, C_augmented, log=True, sparse=True, return_matrix=True) + G_sparse, log_sparse = ot.emd( + a, b, C_augmented, log=True, sparse=True, return_matrix=True + ) - cost_dense = log_dense['cost'] - cost_sparse = log_sparse['cost'] + cost_dense = log_dense["cost"] + cost_sparse = log_sparse["cost"] np.testing.assert_allclose(cost_dense, cost_sparse, rtol=1e-5, atol=1e-7) @@ -992,10 +1001,14 @@ def test_emd_sparse_vs_dense(): def test_emd2_sparse_vs_dense(): + """Test that sparse and dense emd2 solvers produce identical results. + Uses augmented k-NN graph approach: first solves with dense solver to + identify needed edges, then compares both solvers on the same graph. 
+ """ n_source = 100 n_target = 100 - k = 10 + k = 10 rng = np.random.RandomState(42) @@ -1049,7 +1062,9 @@ def test_emd2_sparse_vs_dense(): cols_aug.append(j) data_aug.append(C[i, j]) - C_augmented = coo_matrix((data_aug, (rows_aug, cols_aug)), shape=(n_source, n_target)) + C_augmented = coo_matrix( + (data_aug, (rows_aug, cols_aug)), shape=(n_source, n_target) + ) C_augmented_dense = np.full((n_source, n_target), large_cost) C_augmented_array = C_augmented.toarray() From 022720b295f42223101a87ff7c5a9db10b9c0267 Mon Sep 17 00:00:00 2001 From: nathanneike Date: Thu, 30 Oct 2025 11:08:06 +0100 Subject: [PATCH 3/5] Fix int64_t type compatibility for Linux, remove sparse and return matrix parameter from emd and fix linting issues --- ot/lp/_network_simplex.py | 85 +++++++++++++++++---------------------- ot/lp/emd_wrap.pyx | 16 ++++---- test/test_ot.py | 25 +++++++++--- 3 files changed, 65 insertions(+), 61 deletions(-) diff --git a/ot/lp/_network_simplex.py b/ot/lp/_network_simplex.py index 35b185746..7438bd131 100644 --- a/ot/lp/_network_simplex.py +++ b/ot/lp/_network_simplex.py @@ -12,7 +12,6 @@ import warnings import scipy.sparse as sp -import time from ..utils import list_to_array, check_number_threads from ..backend import get_backend from .emd_wrap import emd_c, emd_c_sparse, check_result @@ -174,8 +173,6 @@ def emd( center_dual=True, numThreads=1, check_marginals=True, - sparse=False, - return_matrix=False, ): r"""Solves the Earth Movers distance problem and returns the OT matrix @@ -236,22 +233,26 @@ def emd( check_marginals: bool, optional (default=True) If True, checks that the marginals mass are equal. If False, skips the check. - sparse: bool, optional (default=False) - If True, uses the sparse solver that only stores edges with finite costs. - When sparse=True, M should be a scipy.sparse matrix. - return_matrix: bool, optional (default=True) - If True, returns the transport matrix. If False and sparse=True, returns - sparse flow representation in log. 
+ + .. note:: The solver automatically detects sparse format when M is provided as: + - A scipy.sparse matrix (coo, csr, csc, etc.) + - A tuple (row_indices, col_indices, costs) representing an edge list + + For sparse inputs, the solver uses a memory-efficient algorithm and returns + the flow in edge format (via log dict) instead of a full matrix. Returns ------- - gamma: array-like, shape (ns, nt) - Optimal transportation matrix for the given - parameters + gamma: array-like, shape (ns, nt), or None + Optimal transportation matrix for the given parameters. + For sparse inputs, returns None (use log=True to get flow in edge format). log: dict, optional - If input log is true, a dictionary containing the - cost and dual variables and exit status + If input log is True, a dictionary containing the cost, dual variables, + and exit status. For sparse inputs with log=True, also contains: + - 'flow_sources': source nodes of flow edges + - 'flow_targets': target nodes of flow edges + - 'flow_values': flow values on edges Examples @@ -287,7 +288,10 @@ def emd( edge_costs = None n1, n2 = None, None - if sparse: + # Auto-detect sparse format + is_sparse = sp.issparse(M) or (isinstance(M, tuple) and len(M) == 3) + + if is_sparse: if sp.issparse(M): if not isinstance(M, sp.coo_matrix): M_coo = sp.coo_matrix(M) @@ -312,10 +316,6 @@ def emd( edge_costs = np.asarray(M[2], dtype=np.float64) n1 = int(edge_sources.max() + 1) n2 = int(edge_targets.max() + 1) - else: - raise ValueError( - "When sparse=True, M must be a scipy sparse matrix or a tuple (row, col, data)" - ) a, b = list_to_array(a, b) else: @@ -343,7 +343,7 @@ def emd( else nx.ones((M.shape[1],), type_as=type_as) / M.shape[1] ) - if sparse: + if is_sparse: a, b = nx.to_numpy(a, b) else: M, a, b = nx.to_numpy(M, a, b) @@ -375,14 +375,11 @@ def emd( numThreads = check_number_threads(numThreads) if edge_sources is not None: + # Sparse solver - never build full matrix flow_sources, flow_targets, flow_values, cost, u, v, 
result_code = emd_c_sparse( a, b, edge_sources, edge_targets, edge_costs, numItermax ) - if return_matrix: - G = np.zeros((len(a), len(b)), dtype=np.float64) - G[flow_sources, flow_targets] = flow_values - else: - G = None + G = None else: G, cost, u, v, result_code = emd_c(a, b, M, numItermax, numThreads) @@ -413,7 +410,8 @@ def emd( log_dict["warning"] = result_code_string log_dict["result_code"] = result_code - if edge_sources is not None and not return_matrix: + if edge_sources is not None: + # For sparse, include flow in edge format log_dict["flow_sources"] = flow_sources log_dict["flow_targets"] = flow_targets log_dict["flow_values"] = flow_values @@ -427,7 +425,7 @@ def emd( return nx.from_numpy(G, type_as=type_as) else: raise ValueError( - "Cannot return matrix when return_matrix=False and sparse=True without log=True" + "For sparse inputs, log=True is required to get the flow in edge format" ) @@ -441,7 +439,6 @@ def emd2( center_dual=True, numThreads=1, check_marginals=True, - sparse=False, return_matrix=False, ): r"""Solves the Earth Movers distance problem and returns the loss @@ -503,11 +500,12 @@ def emd2( check_marginals: bool, optional (default=True) If True, checks that the marginals mass are equal. If False, skips the check. - sparse: bool, optional (default=False) - If True, uses the sparse solver that only stores edges with finite costs. - This is memory-efficient when M has many infinite or forbidden edges. - When sparse=True, M should be a scipy.sparse matrix (coo, csr, or csc format) - or a tuple (row_indices, col_indices, costs) representing the edge list. + + .. note:: The solver automatically detects sparse format when M is provided as: + - A scipy.sparse matrix (coo, csr, csc, etc.) + - A tuple (row_indices, col_indices, costs) representing an edge list + + For sparse inputs, the solver uses a memory-efficient algorithm. Edges not included are treated as having infinite cost (forbidden). 
@@ -554,14 +552,15 @@ def emd2( edge_costs = None n1, n2 = None, None - if sparse: + # Auto-detect sparse format + is_sparse = sp.issparse(M) or (isinstance(M, tuple) and len(M) == 3) + + if is_sparse: if sp.issparse(M): - t0 = time.perf_counter() if not isinstance(M, sp.coo_matrix): M_coo = sp.coo_matrix(M) else: M_coo = M - t1 = time.perf_counter() edge_sources = ( M_coo.row if M_coo.row.dtype == np.int64 else M_coo.row.astype(np.int64) @@ -574,10 +573,6 @@ def emd2( if M_coo.data.dtype == np.float64 else M_coo.data.astype(np.float64) ) - t2 = time.perf_counter() - print( - f"[PY SPARSE] COO conversion: {(t1-t0)*1000:.3f} ms, array copies: {(t2-t1)*1000:.3f} ms" - ) n1, n2 = M_coo.shape elif isinstance(M, tuple) and len(M) == 3: edge_sources = np.asarray(M[0], dtype=np.int64) @@ -585,10 +580,6 @@ def emd2( edge_costs = np.asarray(M[2], dtype=np.float64) n1 = int(edge_sources.max() + 1) n2 = int(edge_targets.max() + 1) - else: - raise ValueError( - "When sparse=True, M must be a scipy sparse matrix or a tuple (row, col, data)" - ) a, b = list_to_array(a, b) else: @@ -618,14 +609,14 @@ def emd2( ) a0, b0 = a, b - M0 = None if sparse else M + M0 = None if is_sparse else M - if sparse: + if is_sparse: edge_costs_original = nx.from_numpy(edge_costs, type_as=type_as) else: edge_costs_original = None - if sparse: + if is_sparse: a, b = nx.to_numpy(a, b) else: M, a, b = nx.to_numpy(M, a, b) diff --git a/ot/lp/emd_wrap.pyx b/ot/lp/emd_wrap.pyx index b4f603605..f95e47433 100644 --- a/ot/lp/emd_wrap.pyx +++ b/ot/lp/emd_wrap.pyx @@ -14,7 +14,7 @@ from ..utils import dist cimport cython cimport libc.math as math -from libc.stdint cimport uint64_t +from libc.stdint cimport uint64_t, int64_t import warnings @@ -22,7 +22,7 @@ import warnings cdef extern from "EMD.h": int EMD_wrap(int n1,int n2, double *X, double *Y,double *D, double *G, double* alpha, double* beta, double *cost, uint64_t maxIter) nogil int EMD_wrap_omp(int n1,int n2, double *X, double *Y,double *D, double *G, 
double* alpha, double* beta, double *cost, uint64_t maxIter, int numThreads) nogil - int EMD_wrap_sparse(int n1, int n2, double *X, double *Y, uint64_t n_edges, long long *edge_sources, long long *edge_targets, double *edge_costs, long long *flow_sources_out, long long *flow_targets_out, double *flow_values_out, uint64_t *n_flows_out, double *alpha, double *beta, double *cost, uint64_t maxIter) nogil + int EMD_wrap_sparse(int n1, int n2, double *X, double *Y, uint64_t n_edges, int64_t *edge_sources, int64_t *edge_targets, double *edge_costs, int64_t *flow_sources_out, int64_t *flow_targets_out, double *flow_values_out, uint64_t *n_flows_out, double *alpha, double *beta, double *cost, uint64_t maxIter) nogil cdef enum ProblemType: INFEASIBLE, OPTIMAL, UNBOUNDED, MAX_ITER_REACHED @@ -212,8 +212,8 @@ def emd_1d_sorted(np.ndarray[double, ndim=1, mode="c"] u_weights, @cython.wraparound(False) def emd_c_sparse(np.ndarray[double, ndim=1, mode="c"] a, np.ndarray[double, ndim=1, mode="c"] b, - np.ndarray[long long, ndim=1, mode="c"] edge_sources, - np.ndarray[long long, ndim=1, mode="c"] edge_targets, + np.ndarray[int64_t, ndim=1, mode="c"] edge_sources, + np.ndarray[int64_t, ndim=1, mode="c"] edge_targets, np.ndarray[double, ndim=1, mode="c"] edge_costs, uint64_t max_iter): """ @@ -259,8 +259,8 @@ def emd_c_sparse(np.ndarray[double, ndim=1, mode="c"] a, cdef double cost = 0 # Allocate output arrays (max size = n_edges) - cdef np.ndarray[long long, ndim=1, mode="c"] flow_sources = np.zeros(n_edges, dtype=np.int64) - cdef np.ndarray[long long, ndim=1, mode="c"] flow_targets = np.zeros(n_edges, dtype=np.int64) + cdef np.ndarray[int64_t, ndim=1, mode="c"] flow_sources = np.zeros(n_edges, dtype=np.int64) + cdef np.ndarray[int64_t, ndim=1, mode="c"] flow_targets = np.zeros(n_edges, dtype=np.int64) cdef np.ndarray[double, ndim=1, mode="c"] flow_values = np.zeros(n_edges, dtype=np.float64) cdef np.ndarray[double, ndim=1, mode="c"] alpha = np.zeros(n1) cdef np.ndarray[double, 
ndim=1, mode="c"] beta = np.zeros(n2) @@ -270,8 +270,8 @@ def emd_c_sparse(np.ndarray[double, ndim=1, mode="c"] a, n1, n2, a.data, b.data, n_edges, - edge_sources.data, edge_targets.data, edge_costs.data, - flow_sources.data, flow_targets.data, flow_values.data, + edge_sources.data, edge_targets.data, edge_costs.data, + flow_sources.data, flow_targets.data, flow_values.data, &n_flows_out, alpha.data, beta.data, &cost, max_iter ) diff --git a/test/test_ot.py b/test/test_ot.py index ec12f63c8..6b57c8602 100644 --- a/test/test_ot.py +++ b/test/test_ot.py @@ -985,19 +985,32 @@ def test_emd_sparse_vs_dense(): C_augmented_dense[C_augmented_array > 0] = C_augmented_array[C_augmented_array > 0] G_dense, log_dense = ot.emd(a, b, C_augmented_dense, log=True) - G_sparse, log_sparse = ot.emd( - a, b, C_augmented, log=True, sparse=True, return_matrix=True - ) + G_sparse, log_sparse = ot.emd(a, b, C_augmented, log=True) cost_dense = log_dense["cost"] cost_sparse = log_sparse["cost"] np.testing.assert_allclose(cost_dense, cost_sparse, rtol=1e-5, atol=1e-7) + # For dense, G_dense is returned; for sparse, reconstruct from flow edges np.testing.assert_allclose(a, G_dense.sum(1), rtol=1e-5, atol=1e-7) np.testing.assert_allclose(b, G_dense.sum(0), rtol=1e-5, atol=1e-7) - np.testing.assert_allclose(a, G_sparse.sum(1), rtol=1e-5, atol=1e-7) - np.testing.assert_allclose(b, G_sparse.sum(0), rtol=1e-5, atol=1e-7) + + # Reconstruct sparse matrix from flow for marginal checks + if G_sparse is None: + G_sparse_reconstructed = np.zeros((n_source, n_target)) + G_sparse_reconstructed[ + log_sparse["flow_sources"], log_sparse["flow_targets"] + ] = log_sparse["flow_values"] + np.testing.assert_allclose( + a, G_sparse_reconstructed.sum(1), rtol=1e-5, atol=1e-7 + ) + np.testing.assert_allclose( + b, G_sparse_reconstructed.sum(0), rtol=1e-5, atol=1e-7 + ) + else: + np.testing.assert_allclose(a, G_sparse.sum(1), rtol=1e-5, atol=1e-7) + np.testing.assert_allclose(b, G_sparse.sum(0), rtol=1e-5, 
atol=1e-7) def test_emd2_sparse_vs_dense(): @@ -1071,7 +1084,7 @@ def test_emd2_sparse_vs_dense(): C_augmented_dense[C_augmented_array > 0] = C_augmented_array[C_augmented_array > 0] cost_dense = ot.emd2(a, b, C_augmented_dense) - cost_sparse = ot.emd2(a, b, C_augmented, sparse=True) + cost_sparse = ot.emd2(a, b, C_augmented) np.testing.assert_allclose(cost_dense, cost_sparse, rtol=1e-5, atol=1e-7) From aa5f1c9431df0385ecb6da9dee58d3c56b4ad2e4 Mon Sep 17 00:00:00 2001 From: nathanneike Date: Mon, 3 Nov 2025 13:47:37 +0100 Subject: [PATCH 4/5] refactor: Clean up sparse EMD implementation - Remove tuple format support for sparse matrices (use scipy.sparse only) - Change index types from int64_t to uint64_t throughout (indices are never negative) - Refactor emd() and emd2() with clear sparse/dense code path separation - Add sparse_bipartitegraph.h to MANIFEST.in to fix build - Add test_emd_sparse_backends() to verify backend compatibility --- MANIFEST.in | 1 + ot/lp/EMD.h | 8 +- ot/lp/EMD_wrapper.cpp | 8 +- ot/lp/_network_simplex.py | 440 +++++++++++++++++--------------------- ot/lp/emd_wrap.pyx | 22 +- test/test_ot.py | 166 ++++++++++++++ 6 files changed, 387 insertions(+), 258 deletions(-) diff --git a/MANIFEST.in b/MANIFEST.in index 7c96ba026..d93298de4 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -10,4 +10,5 @@ include ot/lp/full_bipartitegraph.h include ot/lp/full_bipartitegraph_omp.h include ot/lp/network_simplex_simple.h include ot/lp/network_simplex_simple_omp.h +include ot/lp/sparse_bipartitegraph.h include ot/partial/partial_cython.pyx diff --git a/ot/lp/EMD.h b/ot/lp/EMD.h index efa839bcf..e3564a2d2 100644 --- a/ot/lp/EMD.h +++ b/ot/lp/EMD.h @@ -38,11 +38,11 @@ int EMD_wrap_sparse( double *X, double *Y, uint64_t n_edges, // Number of edges in sparse graph - int64_t *edge_sources, // Source indices for each edge (n_edges) - int64_t *edge_targets, // Target indices for each edge (n_edges) + uint64_t *edge_sources, // Source indices for each edge (n_edges) + 
uint64_t *edge_targets, // Target indices for each edge (n_edges) double *edge_costs, // Cost for each edge (n_edges) - int64_t *flow_sources_out, // Output: source indices of non-zero flows - int64_t *flow_targets_out, // Output: target indices of non-zero flows + uint64_t *flow_sources_out, // Output: source indices of non-zero flows + uint64_t *flow_targets_out, // Output: target indices of non-zero flows double *flow_values_out, // Output: flow values uint64_t *n_flows_out, double *alpha, // Output: dual variables for sources (n1) diff --git a/ot/lp/EMD_wrapper.cpp b/ot/lp/EMD_wrapper.cpp index 7b4b9ed6e..bd3672535 100644 --- a/ot/lp/EMD_wrapper.cpp +++ b/ot/lp/EMD_wrapper.cpp @@ -228,11 +228,11 @@ int EMD_wrap_sparse( double *X, double *Y, uint64_t n_edges, - int64_t *edge_sources, - int64_t *edge_targets, + uint64_t *edge_sources, + uint64_t *edge_targets, double *edge_costs, - int64_t *flow_sources_out, - int64_t *flow_targets_out, + uint64_t *flow_sources_out, + uint64_t *flow_targets_out, double *flow_values_out, uint64_t *n_flows_out, double *alpha, diff --git a/ot/lp/_network_simplex.py b/ot/lp/_network_simplex.py index 7438bd131..8d033679c 100644 --- a/ot/lp/_network_simplex.py +++ b/ot/lp/_network_simplex.py @@ -234,9 +234,8 @@ def emd( If True, checks that the marginals mass are equal. If False, skips the check. - .. note:: The solver automatically detects sparse format when M is provided as: - - A scipy.sparse matrix (coo, csr, csc, etc.) - - A tuple (row_indices, col_indices, costs) representing an edge list + .. note:: The solver automatically detects sparse format when M is provided as + a scipy.sparse matrix (coo, csr, csc, etc.). For sparse inputs, the solver uses a memory-efficient algorithm and returns the flow in edge format (via log dict) instead of a full matrix. 
@@ -288,36 +287,35 @@ def emd( edge_costs = None n1, n2 = None, None - # Auto-detect sparse format - is_sparse = sp.issparse(M) or (isinstance(M, tuple) and len(M) == 3) + # Check for sparse format + is_sparse = sp.issparse(M) if is_sparse: - if sp.issparse(M): - if not isinstance(M, sp.coo_matrix): - M_coo = sp.coo_matrix(M) - else: - M_coo = M + # Convert to COO format for edge extraction + if not isinstance(M, sp.coo_matrix): + M_coo = sp.coo_matrix(M) + else: + M_coo = M - edge_sources = ( - M_coo.row if M_coo.row.dtype == np.int64 else M_coo.row.astype(np.int64) - ) - edge_targets = ( - M_coo.col if M_coo.col.dtype == np.int64 else M_coo.col.astype(np.int64) - ) - edge_costs = ( - M_coo.data - if M_coo.data.dtype == np.float64 - else M_coo.data.astype(np.float64) - ) - n1, n2 = M_coo.shape - elif isinstance(M, tuple) and len(M) == 3: - edge_sources = np.asarray(M[0], dtype=np.int64) - edge_targets = np.asarray(M[1], dtype=np.int64) - edge_costs = np.asarray(M[2], dtype=np.float64) - n1 = int(edge_sources.max() + 1) - n2 = int(edge_targets.max() + 1) + edge_sources = ( + M_coo.row if M_coo.row.dtype == np.uint64 else M_coo.row.astype(np.uint64) + ) + edge_targets = ( + M_coo.col if M_coo.col.dtype == np.uint64 else M_coo.col.astype(np.uint64) + ) + edge_costs = ( + M_coo.data + if M_coo.data.dtype == np.float64 + else M_coo.data.astype(np.float64) + ) + n1, n2 = M_coo.shape a, b = list_to_array(a, b) + elif isinstance(M, tuple): + raise ValueError( + "Tuple format for sparse cost matrix is not supported. " + "Please use scipy.sparse format (e.g., scipy.sparse.coo_matrix, csr_matrix, etc.)." 
+ ) else: a, b, M = list_to_array(a, b, M) @@ -330,18 +328,14 @@ def emd( else: type_as = a + # Set n1, n2 if not already set (dense case) + if n1 is None: + n1, n2 = M.shape + if len(a) == 0: - a = ( - nx.ones((n1,), type_as=type_as) / n1 - if n1 - else nx.ones((M.shape[0],), type_as=type_as) / M.shape[0] - ) + a = nx.ones((n1,), type_as=type_as) / n1 if len(b) == 0: - b = ( - nx.ones((n2,), type_as=type_as) / n2 - if n2 - else nx.ones((M.shape[1],), type_as=type_as) / M.shape[1] - ) + b = nx.ones((n2,), type_as=type_as) / n2 if is_sparse: a, b = nx.to_numpy(a, b) @@ -352,9 +346,6 @@ def emd( a = np.asarray(a, dtype=np.float64) b = np.asarray(b, dtype=np.float64) - if n1 is None: - n1, n2 = M.shape - assert ( a.shape[0] == n1 and b.shape[0] == n2 ), "Dimension mismatch, check dimensions of M with a and b" @@ -374,59 +365,77 @@ def emd( numThreads = check_number_threads(numThreads) - if edge_sources is not None: + # ============================================================================ + # SPARSE SOLVER PATH + # ============================================================================ + if is_sparse: # Sparse solver - never build full matrix flow_sources, flow_targets, flow_values, cost, u, v, result_code = emd_c_sparse( a, b, edge_sources, edge_targets, edge_costs, numItermax ) - G = None - else: - G, cost, u, v, result_code = emd_c(a, b, M, numItermax, numThreads) - - if center_dual: - u, v = center_ot_dual(u, v, a, b) - if np.any(~asel) or np.any(~bsel): - if edge_sources is not None: + # Center dual potentials + if center_dual: u, v = center_ot_dual(u, v, a, b) - else: - u, v = estimate_dual_null_weights(u, v, a, b, M) - result_code_string = check_result(result_code) - if not nx.is_floating_point(type_as): - warnings.warn( - "Input histogram consists of integer. The transport plan will be " - "casted accordingly, possibly resulting in a loss of precision. 
" - "If this behaviour is unwanted, please make sure your input " - "histogram consists of floating point elements.", - stacklevel=2, - ) + if np.any(~asel) or np.any(~bsel): + u, v = center_ot_dual(u, v, a, b) - if log: - log_dict = {} - log_dict["cost"] = cost - log_dict["u"] = nx.from_numpy(u, type_as=type_as) - log_dict["v"] = nx.from_numpy(v, type_as=type_as) - log_dict["warning"] = result_code_string - log_dict["result_code"] = result_code + result_code_string = check_result(result_code) - if edge_sources is not None: - # For sparse, include flow in edge format + if log: + log_dict = {} + log_dict["cost"] = cost + log_dict["u"] = nx.from_numpy(u, type_as=type_as) + log_dict["v"] = nx.from_numpy(v, type_as=type_as) + log_dict["warning"] = result_code_string + log_dict["result_code"] = result_code log_dict["flow_sources"] = flow_sources log_dict["flow_targets"] = flow_targets log_dict["flow_values"] = flow_values - if G is not None: - return nx.from_numpy(G, type_as=type_as), log_dict - else: return None, log_dict + else: + raise ValueError( + "For sparse inputs, log=True is required to get the flow in edge format" + ) - if G is not None: - return nx.from_numpy(G, type_as=type_as) + # ============================================================================ + # DENSE SOLVER PATH + # ============================================================================ else: - raise ValueError( - "For sparse inputs, log=True is required to get the flow in edge format" - ) + # Dense solver + G, cost, u, v, result_code = emd_c(a, b, M, numItermax, numThreads) + + # Center dual potentials + if center_dual: + u, v = center_ot_dual(u, v, a, b) + + if np.any(~asel) or np.any(~bsel): + u, v = estimate_dual_null_weights(u, v, a, b, M) + + result_code_string = check_result(result_code) + + if not nx.is_floating_point(type_as): + warnings.warn( + "Input histogram consists of integer. The transport plan will be " + "casted accordingly, possibly resulting in a loss of precision. 
" + "If this behaviour is unwanted, please make sure your input " + "histogram consists of floating point elements.", + stacklevel=2, + ) + + if log: + log_dict = {} + log_dict["cost"] = cost + log_dict["u"] = nx.from_numpy(u, type_as=type_as) + log_dict["v"] = nx.from_numpy(v, type_as=type_as) + log_dict["warning"] = result_code_string + log_dict["result_code"] = result_code + + return nx.from_numpy(G, type_as=type_as), log_dict + else: + return nx.from_numpy(G, type_as=type_as) def emd2( @@ -501,9 +510,8 @@ def emd2( If True, checks that the marginals mass are equal. If False, skips the check. - .. note:: The solver automatically detects sparse format when M is provided as: - - A scipy.sparse matrix (coo, csr, csc, etc.) - - A tuple (row_indices, col_indices, costs) representing an edge list + .. note:: The solver automatically detects sparse format when M is provided as + a scipy.sparse matrix (coo, csr, csc, etc.). For sparse inputs, the solver uses a memory-efficient algorithm. Edges not included are treated as having infinite cost (forbidden). 
@@ -552,36 +560,35 @@ def emd2( edge_costs = None n1, n2 = None, None - # Auto-detect sparse format - is_sparse = sp.issparse(M) or (isinstance(M, tuple) and len(M) == 3) + # Check for sparse format + is_sparse = sp.issparse(M) if is_sparse: - if sp.issparse(M): - if not isinstance(M, sp.coo_matrix): - M_coo = sp.coo_matrix(M) - else: - M_coo = M + # Convert to COO format for edge extraction + if not isinstance(M, sp.coo_matrix): + M_coo = sp.coo_matrix(M) + else: + M_coo = M - edge_sources = ( - M_coo.row if M_coo.row.dtype == np.int64 else M_coo.row.astype(np.int64) - ) - edge_targets = ( - M_coo.col if M_coo.col.dtype == np.int64 else M_coo.col.astype(np.int64) - ) - edge_costs = ( - M_coo.data - if M_coo.data.dtype == np.float64 - else M_coo.data.astype(np.float64) - ) - n1, n2 = M_coo.shape - elif isinstance(M, tuple) and len(M) == 3: - edge_sources = np.asarray(M[0], dtype=np.int64) - edge_targets = np.asarray(M[1], dtype=np.int64) - edge_costs = np.asarray(M[2], dtype=np.float64) - n1 = int(edge_sources.max() + 1) - n2 = int(edge_targets.max() + 1) + edge_sources = ( + M_coo.row if M_coo.row.dtype == np.uint64 else M_coo.row.astype(np.uint64) + ) + edge_targets = ( + M_coo.col if M_coo.col.dtype == np.uint64 else M_coo.col.astype(np.uint64) + ) + edge_costs = ( + M_coo.data + if M_coo.data.dtype == np.float64 + else M_coo.data.astype(np.float64) + ) + n1, n2 = M_coo.shape a, b = list_to_array(a, b) + elif isinstance(M, tuple): + raise ValueError( + "Tuple format for sparse cost matrix is not supported. " + "Please use scipy.sparse format (e.g., scipy.sparse.coo_matrix, csr_matrix, etc.)." 
+ ) else: a, b, M = list_to_array(a, b, M) @@ -594,19 +601,15 @@ def emd2( else: type_as = a # Can't use M for sparse case + # Set n1, n2 if not already set (dense case) + if n1 is None: + n1, n2 = M.shape + # if empty array given then use uniform distributions if len(a) == 0: - a = ( - nx.ones((n1,), type_as=type_as) / n1 - if n1 - else nx.ones((M.shape[0],), type_as=type_as) / M.shape[0] - ) + a = nx.ones((n1,), type_as=type_as) / n1 if len(b) == 0: - b = ( - nx.ones((n2,), type_as=type_as) / n2 - if n2 - else nx.ones((M.shape[1],), type_as=type_as) / M.shape[1] - ) + b = nx.ones((n2,), type_as=type_as) / n2 a0, b0 = a, b M0 = None if is_sparse else M @@ -625,9 +628,6 @@ def emd2( a = np.asarray(a, dtype=np.float64) b = np.asarray(b, dtype=np.float64) - if n1 is None: - n1, n2 = M.shape - assert ( a.shape[0] == n1 and b.shape[0] == n2 ), "Dimension mismatch, check dimensions of M with a and b" @@ -646,127 +646,88 @@ def emd2( numThreads = check_number_threads(numThreads) - if log or return_matrix: + # ============================================================================ + # SPARSE SOLVER PATH + # ============================================================================ + if is_sparse: def f(b): bsel = b != 0 - if edge_sources is not None: - flow_sources, flow_targets, flow_values, cost, u, v, result_code = ( - emd_c_sparse( - a, b, edge_sources, edge_targets, edge_costs, numItermax - ) - ) + # Solve sparse EMD + flow_sources, flow_targets, flow_values, cost, u, v, result_code = ( + emd_c_sparse(a, b, edge_sources, edge_targets, edge_costs, numItermax) + ) - edge_to_idx = { - (edge_sources[k], edge_targets[k]): k - for k in range(len(edge_sources)) - } - - grad_edge_costs = np.zeros(len(edge_costs), dtype=np.float64) - for idx in range(len(flow_sources)): - src, tgt, flow = ( - flow_sources[idx], - flow_targets[idx], - flow_values[idx], - ) - edge_idx = edge_to_idx.get((src, tgt), -1) - if edge_idx >= 0: - grad_edge_costs[edge_idx] = flow - - if 
return_matrix: - G = np.zeros((len(a), len(b)), dtype=np.float64) - G[flow_sources, flow_targets] = flow_values - else: - G = None - else: - G, cost, u, v, result_code = emd_c(a, b, M, numItermax, numThreads) + # Build gradient mapping for edge costs + edge_to_idx = { + (edge_sources[k], edge_targets[k]): k for k in range(len(edge_sources)) + } + + grad_edge_costs = np.zeros(len(edge_costs), dtype=np.float64) + for idx in range(len(flow_sources)): + src, tgt, flow = ( + flow_sources[idx], + flow_targets[idx], + flow_values[idx], + ) + edge_idx = edge_to_idx.get((src, tgt), -1) + if edge_idx >= 0: + grad_edge_costs[edge_idx] = flow + # Center dual potentials if center_dual: u, v = center_ot_dual(u, v, a, b) if np.any(~asel) or np.any(~bsel): - if edge_sources is not None: - u, v = center_ot_dual(u, v, a, b) - else: - u, v = estimate_dual_null_weights(u, v, a, b, M) + u, v = center_ot_dual(u, v, a, b) - result_code_string = check_result(result_code) - log = {} - if not nx.is_floating_point(type_as): - warnings.warn( - "Input histogram consists of integer. The transport plan will be " - "casted accordingly, possibly resulting in a loss of precision. 
" - "If this behaviour is unwanted, please make sure your input " - "histogram consists of floating point elements.", - stacklevel=2, - ) + # Prepare cost with gradients + cost = nx.set_gradients( + nx.from_numpy(cost, type_as=type_as), + (a0, b0, edge_costs_original), + ( + nx.from_numpy(u - np.mean(u), type_as=type_as), + nx.from_numpy(v - np.mean(v), type_as=type_as), + nx.from_numpy(grad_edge_costs, type_as=type_as), + ), + ) + + check_result(result_code) + + if log or return_matrix: + log_dict = {} + log_dict["u"] = nx.from_numpy(u, type_as=type_as) + log_dict["v"] = nx.from_numpy(v, type_as=type_as) + log_dict["warning"] = check_result(result_code) + log_dict["result_code"] = result_code - if G is not None: - G = nx.from_numpy(G, type_as=type_as) if return_matrix: - log["G"] = G - log["u"] = nx.from_numpy(u, type_as=type_as) - log["v"] = nx.from_numpy(v, type_as=type_as) - log["warning"] = result_code_string - log["result_code"] = result_code - - if edge_sources is not None: - cost = nx.set_gradients( - nx.from_numpy(cost, type_as=type_as), - (a0, b0, edge_costs_original), - ( - log["u"] - nx.mean(log["u"]), - log["v"] - nx.mean(log["v"]), - nx.from_numpy(grad_edge_costs, type_as=type_as), - ), - ) + G = np.zeros((len(a), len(b)), dtype=np.float64) + G[flow_sources, flow_targets] = flow_values + log_dict["G"] = nx.from_numpy(G, type_as=type_as) + + return [cost, log_dict] else: - cost = nx.set_gradients( - nx.from_numpy(cost, type_as=type_as), - (a0, b0, M0), - (log["u"] - nx.mean(log["u"]), log["v"] - nx.mean(log["v"]), G), - ) - return [cost, log] + return cost + + # ============================================================================ + # DENSE SOLVER PATH + # ============================================================================ else: def f(b): bsel = b != 0 - if edge_sources is not None: - flow_sources, flow_targets, flow_values, cost, u, v, result_code = ( - emd_c_sparse( - a, b, edge_sources, edge_targets, edge_costs, numItermax - ) - ) - 
- edge_to_idx = { - (edge_sources[k], edge_targets[k]): k - for k in range(len(edge_sources)) - } - grad_edge_costs = np.zeros(len(edge_costs), dtype=np.float64) - for idx in range(len(flow_sources)): - src, tgt, flow = ( - flow_sources[idx], - flow_targets[idx], - flow_values[idx], - ) - edge_idx = edge_to_idx.get((src, tgt), -1) - if edge_idx >= 0: - grad_edge_costs[edge_idx] = flow - - G = None - else: - G, cost, u, v, result_code = emd_c(a, b, M, numItermax, numThreads) + # Solve dense EMD + G, cost, u, v, result_code = emd_c(a, b, M, numItermax, numThreads) + # Center dual potentials if center_dual: u, v = center_ot_dual(u, v, a, b) if np.any(~asel) or np.any(~bsel): - if edge_sources is not None: - u, v = center_ot_dual(u, v, a, b) - else: - u, v = estimate_dual_null_weights(u, v, a, b, M) + u, v = estimate_dual_null_weights(u, v, a, b, M) if not nx.is_floating_point(type_as): warnings.warn( @@ -777,31 +738,32 @@ def f(b): stacklevel=2, ) - if edge_sources is not None: - # Sparse: gradient w.r.t. 
edge_costs (no need to convert G) - cost = nx.set_gradients( - nx.from_numpy(cost, type_as=type_as), - (a0, b0, edge_costs_original), - ( - nx.from_numpy(u - np.mean(u), type_as=type_as), - nx.from_numpy(v - np.mean(v), type_as=type_as), - nx.from_numpy(grad_edge_costs, type_as=type_as), - ), - ) - else: - G = nx.from_numpy(G, type_as=type_as) - cost = nx.set_gradients( - nx.from_numpy(cost, type_as=type_as), - (a0, b0, M0), - ( - nx.from_numpy(u - np.mean(u), type_as=type_as), - nx.from_numpy(v - np.mean(v), type_as=type_as), - G, - ), - ) + G = nx.from_numpy(G, type_as=type_as) + cost = nx.set_gradients( + nx.from_numpy(cost, type_as=type_as), + (a0, b0, M0), + ( + nx.from_numpy(u - np.mean(u), type_as=type_as), + nx.from_numpy(v - np.mean(v), type_as=type_as), + G, + ), + ) check_result(result_code) - return cost + + if log or return_matrix: + log_dict = {} + log_dict["u"] = nx.from_numpy(u, type_as=type_as) + log_dict["v"] = nx.from_numpy(v, type_as=type_as) + log_dict["warning"] = check_result(result_code) + log_dict["result_code"] = result_code + + if return_matrix: + log_dict["G"] = G + + return [cost, log_dict] + else: + return cost if len(b.shape) == 1: return f(b) diff --git a/ot/lp/emd_wrap.pyx b/ot/lp/emd_wrap.pyx index f95e47433..3b19d3fdd 100644 --- a/ot/lp/emd_wrap.pyx +++ b/ot/lp/emd_wrap.pyx @@ -22,7 +22,7 @@ import warnings cdef extern from "EMD.h": int EMD_wrap(int n1,int n2, double *X, double *Y,double *D, double *G, double* alpha, double* beta, double *cost, uint64_t maxIter) nogil int EMD_wrap_omp(int n1,int n2, double *X, double *Y,double *D, double *G, double* alpha, double* beta, double *cost, uint64_t maxIter, int numThreads) nogil - int EMD_wrap_sparse(int n1, int n2, double *X, double *Y, uint64_t n_edges, int64_t *edge_sources, int64_t *edge_targets, double *edge_costs, int64_t *flow_sources_out, int64_t *flow_targets_out, double *flow_values_out, uint64_t *n_flows_out, double *alpha, double *beta, double *cost, uint64_t maxIter) nogil 
+ int EMD_wrap_sparse(int n1, int n2, double *X, double *Y, uint64_t n_edges, uint64_t *edge_sources, uint64_t *edge_targets, double *edge_costs, uint64_t *flow_sources_out, uint64_t *flow_targets_out, double *flow_values_out, uint64_t *n_flows_out, double *alpha, double *beta, double *cost, uint64_t maxIter) nogil cdef enum ProblemType: INFEASIBLE, OPTIMAL, UNBOUNDED, MAX_ITER_REACHED @@ -212,8 +212,8 @@ def emd_1d_sorted(np.ndarray[double, ndim=1, mode="c"] u_weights, @cython.wraparound(False) def emd_c_sparse(np.ndarray[double, ndim=1, mode="c"] a, np.ndarray[double, ndim=1, mode="c"] b, - np.ndarray[int64_t, ndim=1, mode="c"] edge_sources, - np.ndarray[int64_t, ndim=1, mode="c"] edge_targets, + np.ndarray[uint64_t, ndim=1, mode="c"] edge_sources, + np.ndarray[uint64_t, ndim=1, mode="c"] edge_targets, np.ndarray[double, ndim=1, mode="c"] edge_costs, uint64_t max_iter): """ @@ -225,9 +225,9 @@ def emd_c_sparse(np.ndarray[double, ndim=1, mode="c"] a, Source histogram b : (n2,) array Target histogram - edge_sources : (k,) array, int64 + edge_sources : (k,) array, uint64 Source indices for each edge - edge_targets : (k,) array, int64 + edge_targets : (k,) array, uint64 Target indices for each edge edge_costs : (k,) array, float64 Cost for each edge @@ -236,9 +236,9 @@ def emd_c_sparse(np.ndarray[double, ndim=1, mode="c"] a, Returns ------- - flow_sources : (n_flows,) array, int64 + flow_sources : (n_flows,) array, uint64 Source indices of non-zero flows - flow_targets : (n_flows,) array, int64 + flow_targets : (n_flows,) array, uint64 Target indices of non-zero flows flow_values : (n_flows,) array, float64 Flow values @@ -259,8 +259,8 @@ def emd_c_sparse(np.ndarray[double, ndim=1, mode="c"] a, cdef double cost = 0 # Allocate output arrays (max size = n_edges) - cdef np.ndarray[int64_t, ndim=1, mode="c"] flow_sources = np.zeros(n_edges, dtype=np.int64) - cdef np.ndarray[int64_t, ndim=1, mode="c"] flow_targets = np.zeros(n_edges, dtype=np.int64) + cdef 
np.ndarray[uint64_t, ndim=1, mode="c"] flow_sources = np.zeros(n_edges, dtype=np.uint64) + cdef np.ndarray[uint64_t, ndim=1, mode="c"] flow_targets = np.zeros(n_edges, dtype=np.uint64) cdef np.ndarray[double, ndim=1, mode="c"] flow_values = np.zeros(n_edges, dtype=np.float64) cdef np.ndarray[double, ndim=1, mode="c"] alpha = np.zeros(n1) cdef np.ndarray[double, ndim=1, mode="c"] beta = np.zeros(n2) @@ -270,8 +270,8 @@ def emd_c_sparse(np.ndarray[double, ndim=1, mode="c"] a, n1, n2, a.data, b.data, n_edges, - edge_sources.data, edge_targets.data, edge_costs.data, - flow_sources.data, flow_targets.data, flow_values.data, + edge_sources.data, edge_targets.data, edge_costs.data, + flow_sources.data, flow_targets.data, flow_values.data, &n_flows_out, alpha.data, beta.data, &cost, max_iter ) diff --git a/test/test_ot.py b/test/test_ot.py index 6b57c8602..27c9200d8 100644 --- a/test/test_ot.py +++ b/test/test_ot.py @@ -1089,6 +1089,172 @@ def test_emd2_sparse_vs_dense(): np.testing.assert_allclose(cost_dense, cost_sparse, rtol=1e-5, atol=1e-7) +def test_emd_sparse_backends(nx): + """Test that sparse EMD works with different backends for weights a and b. + + Uses augmented k-NN graph approach to ensure feasibility. 
+ """ + n_source = 50 + n_target = 50 + k = 10 + + rng = np.random.RandomState(42) + + # Create distributions + a = ot.utils.unif(n_source) + b = ot.utils.unif(n_target) + + # Create cost matrix + x_source = rng.randn(n_source, 2) + x_target = rng.randn(n_target, 2) + 0.5 + C = ot.dist(x_source, x_target) + + # Create sparse k-NN graph + rows = [] + cols = [] + data = [] + + for i in range(n_source): + distances = C[i, :] + nearest_k = np.argpartition(distances, k)[:k] + for j in nearest_k: + rows.append(i) + cols.append(j) + data.append(C[i, j]) + + C_knn = coo_matrix((data, (rows, cols)), shape=(n_source, n_target)) + + # Augment with necessary edges (same approach as test_emd_sparse_vs_dense) + large_cost = 1e8 + C_dense_infty = np.full((n_source, n_target), large_cost) + C_knn_array = C_knn.toarray() + C_dense_infty[C_knn_array > 0] = C_knn_array[C_knn_array > 0] + + G_dense_initial = ot.emd(a, b, C_dense_infty) + eps = 1e-9 + active_mask = G_dense_initial > eps + knn_mask = C_knn_array > 0 + extra_edges_mask = active_mask & ~knn_mask + + rows_aug = [] + cols_aug = [] + data_aug = [] + + knn_rows, knn_cols = np.where(knn_mask) + for i, j in zip(knn_rows, knn_cols): + rows_aug.append(i) + cols_aug.append(j) + data_aug.append(C[i, j]) + + extra_rows, extra_cols = np.where(extra_edges_mask) + for i, j in zip(extra_rows, extra_cols): + rows_aug.append(i) + cols_aug.append(j) + data_aug.append(C[i, j]) + + C_augmented = coo_matrix( + (data_aug, (rows_aug, cols_aug)), shape=(n_source, n_target) + ) + + # Test with numpy weights (baseline) + _, log_np = ot.emd(a, b, C_augmented, log=True) + + # Test with backend weights + ab, bb = nx.from_numpy(a, b) + _, log_backend = ot.emd(ab, bb, C_augmented, log=True) + + # Compare costs + cost_np = log_np["cost"] + cost_backend = nx.to_numpy(log_backend["cost"]) + + np.testing.assert_allclose(cost_np, cost_backend, rtol=1e-5, atol=1e-7) + + # Check flow values match + np.testing.assert_allclose( + log_np["flow_values"], 
log_backend["flow_values"], rtol=1e-5, atol=1e-7 + ) + + +def test_emd2_sparse_backends(nx): + """Test that sparse emd2 works with different backends for weights a and b. + + Uses augmented k-NN graph approach to ensure feasibility. + """ + n_source = 50 + n_target = 50 + k = 10 + + rng = np.random.RandomState(42) + + # Create distributions + a = ot.utils.unif(n_source) + b = ot.utils.unif(n_target) + + # Create cost matrix + x_source = rng.randn(n_source, 2) + x_target = rng.randn(n_target, 2) + 0.5 + C = ot.dist(x_source, x_target) + + # Create sparse k-NN graph + rows = [] + cols = [] + data = [] + + for i in range(n_source): + distances = C[i, :] + nearest_k = np.argpartition(distances, k)[:k] + for j in nearest_k: + rows.append(i) + cols.append(j) + data.append(C[i, j]) + + C_knn = coo_matrix((data, (rows, cols)), shape=(n_source, n_target)) + + # Augment with necessary edges (same approach as test_emd2_sparse_vs_dense) + large_cost = 1e8 + C_dense_infty = np.full((n_source, n_target), large_cost) + C_knn_array = C_knn.toarray() + C_dense_infty[C_knn_array > 0] = C_knn_array[C_knn_array > 0] + + G_dense_initial = ot.emd(a, b, C_dense_infty) + eps = 1e-9 + active_mask = G_dense_initial > eps + knn_mask = C_knn_array > 0 + extra_edges_mask = active_mask & ~knn_mask + + rows_aug = [] + cols_aug = [] + data_aug = [] + + knn_rows, knn_cols = np.where(knn_mask) + for i, j in zip(knn_rows, knn_cols): + rows_aug.append(i) + cols_aug.append(j) + data_aug.append(C[i, j]) + + extra_rows, extra_cols = np.where(extra_edges_mask) + for i, j in zip(extra_rows, extra_cols): + rows_aug.append(i) + cols_aug.append(j) + data_aug.append(C[i, j]) + + C_augmented = coo_matrix( + (data_aug, (rows_aug, cols_aug)), shape=(n_source, n_target) + ) + + # Test with numpy weights (baseline) + cost_np = ot.emd2(a, b, C_augmented) + + # Test with backend weights + ab, bb = nx.from_numpy(a, b) + cost_backend = ot.emd2(ab, bb, C_augmented) + + # Compare costs + cost_backend_np = 
nx.to_numpy(cost_backend) + + np.testing.assert_allclose(cost_np, cost_backend_np, rtol=1e-5, atol=1e-7) + + def check_duality_gap(a, b, M, G, u, v, cost): cost_dual = np.vdot(a, u) + np.vdot(b, v) # Check that dual and primal cost are equal From fae9f026a8abd85511de0903fc0409bbfcf1acf7 Mon Sep 17 00:00:00 2001 From: nathanneike Date: Mon, 3 Nov 2025 13:48:49 +0100 Subject: [PATCH 5/5] fix : Quick test file fix --- test/test_ot.py | 13 ------------- 1 file changed, 13 deletions(-) diff --git a/test/test_ot.py b/test/test_ot.py index 27c9200d8..691f77382 100644 --- a/test/test_ot.py +++ b/test/test_ot.py @@ -1100,16 +1100,13 @@ def test_emd_sparse_backends(nx): rng = np.random.RandomState(42) - # Create distributions a = ot.utils.unif(n_source) b = ot.utils.unif(n_target) - # Create cost matrix x_source = rng.randn(n_source, 2) x_target = rng.randn(n_target, 2) + 0.5 C = ot.dist(x_source, x_target) - # Create sparse k-NN graph rows = [] cols = [] data = [] @@ -1156,20 +1153,16 @@ def test_emd_sparse_backends(nx): (data_aug, (rows_aug, cols_aug)), shape=(n_source, n_target) ) - # Test with numpy weights (baseline) _, log_np = ot.emd(a, b, C_augmented, log=True) - # Test with backend weights ab, bb = nx.from_numpy(a, b) _, log_backend = ot.emd(ab, bb, C_augmented, log=True) - # Compare costs cost_np = log_np["cost"] cost_backend = nx.to_numpy(log_backend["cost"]) np.testing.assert_allclose(cost_np, cost_backend, rtol=1e-5, atol=1e-7) - # Check flow values match np.testing.assert_allclose( log_np["flow_values"], log_backend["flow_values"], rtol=1e-5, atol=1e-7 ) @@ -1186,16 +1179,13 @@ def test_emd2_sparse_backends(nx): rng = np.random.RandomState(42) - # Create distributions a = ot.utils.unif(n_source) b = ot.utils.unif(n_target) - # Create cost matrix x_source = rng.randn(n_source, 2) x_target = rng.randn(n_target, 2) + 0.5 C = ot.dist(x_source, x_target) - # Create sparse k-NN graph rows = [] cols = [] data = [] @@ -1242,14 +1232,11 @@ def 
test_emd2_sparse_backends(nx): (data_aug, (rows_aug, cols_aug)), shape=(n_source, n_target) ) - # Test with numpy weights (baseline) cost_np = ot.emd2(a, b, C_augmented) - # Test with backend weights ab, bb = nx.from_numpy(a, b) cost_backend = ot.emd2(ab, bb, C_augmented) - # Compare costs cost_backend_np = nx.to_numpy(cost_backend) np.testing.assert_allclose(cost_np, cost_backend_np, rtol=1e-5, atol=1e-7)