diff --git a/MANIFEST.in b/MANIFEST.in
index 7c96ba026..d93298de4 100644
--- a/MANIFEST.in
+++ b/MANIFEST.in
@@ -10,4 +10,5 @@ include ot/lp/full_bipartitegraph.h
 include ot/lp/full_bipartitegraph_omp.h
 include ot/lp/network_simplex_simple.h
 include ot/lp/network_simplex_simple_omp.h
+include ot/lp/sparse_bipartitegraph.h
 include ot/partial/partial_cython.pyx
diff --git a/ot/lp/EMD.h b/ot/lp/EMD.h
index b56f0601b..e3564a2d2 100644
--- a/ot/lp/EMD.h
+++ b/ot/lp/EMD.h
@@ -32,6 +32,24 @@ enum ProblemType {
 int EMD_wrap(int n1,int n2, double *X, double *Y,double *D, double *G, double* alpha, double* beta, double *cost, uint64_t maxIter);
 int EMD_wrap_omp(int n1,int n2, double *X, double *Y,double *D, double *G, double* alpha, double* beta, double *cost, uint64_t maxIter, int numThreads);
 
+int EMD_wrap_sparse(
+    int n1,
+    int n2,
+    double *X,
+    double *Y,
+    uint64_t n_edges,            // Number of edges in sparse graph
+    uint64_t *edge_sources,      // Source indices for each edge (n_edges)
+    uint64_t *edge_targets,      // Target indices for each edge (n_edges)
+    double *edge_costs,          // Cost for each edge (n_edges)
+    uint64_t *flow_sources_out,  // Output: source indices of non-zero flows
+    uint64_t *flow_targets_out,  // Output: target indices of non-zero flows
+    double *flow_values_out,     // Output: flow values
+    uint64_t *n_flows_out,       // Output: number of non-zero flows
+    double *alpha,               // Output: dual variables for sources (n1)
+    double *beta,                // Output: dual variables for targets (n2)
+    double *cost,                // Output: total transportation cost
+    uint64_t maxIter             // Maximum iterations for solver
+);
 
 #endif
diff --git a/ot/lp/EMD_wrapper.cpp b/ot/lp/EMD_wrapper.cpp
index 4aa5a6e72..bd3672535 100644
--- a/ot/lp/EMD_wrapper.cpp
+++ b/ot/lp/EMD_wrapper.cpp
@@ -15,8 +15,10 @@
 
 #include "network_simplex_simple.h"
 #include "network_simplex_simple_omp.h"
+#include "sparse_bipartitegraph.h"
 #include "EMD.h"
 #include <cstdint>
+#include <vector>
 
 int EMD_wrap(int n1, int n2, double *X, double *Y, double *D, double *G,
@@ -216,3 +218,156 @@ int EMD_wrap_omp(int n1, int n2, double *X, double *Y, double *D, double *G,
 
     return ret;
 }
+
+// ============================================================================
+// SPARSE VERSION: Accepts edge list instead of dense cost matrix
+// ============================================================================
+int EMD_wrap_sparse(
+    int n1,
+    int n2,
+    double *X,
+    double *Y,
+    uint64_t n_edges,
+    uint64_t *edge_sources,
+    uint64_t *edge_targets,
+    double *edge_costs,
+    uint64_t *flow_sources_out,
+    uint64_t *flow_targets_out,
+    double *flow_values_out,
+    uint64_t *n_flows_out,
+    double *alpha,
+    double *beta,
+    double *cost,
+    uint64_t maxIter
+) {
+    using namespace lemon;
+
+    // Count nodes with strictly positive mass; any negative mass is infeasible
+    uint64_t n = 0;
+    for (int i = 0; i < n1; i++) {
+        double val = *(X + i);
+        if (val > 0) {
+            n++;
+        } else if (val < 0) {
+            return INFEASIBLE;
+        }
+    }
+
+    uint64_t m = 0;
+    for (int i = 0; i < n2; i++) {
+        double val = *(Y + i);
+        if (val > 0) {
+            m++;
+        } else if (val < 0) {
+            return INFEASIBLE;
+        }
+    }
+
+    std::vector<uint64_t> indI(n);    // indI[graph_idx] = original_source_idx
+    std::vector<uint64_t> indJ(m);    // indJ[graph_idx] = original_target_idx
+    std::vector<double> weights1(n);  // Source masses (positive only)
+    std::vector<double> weights2(m);  // Target masses (negative for demand)
+
+    // Create reverse mapping: original_idx → graph_idx
+    std::vector<int64_t> source_to_graph(n1, -1);
+    std::vector<int64_t> target_to_graph(n2, -1);
+
+    uint64_t cur = 0;
+    for (int i = 0; i < n1; i++) {
+        double val = *(X + i);
+        if (val > 0) {
+            weights1[cur] = val;       // Store the mass
+            indI[cur] = i;             // Forward map: graph → original
+            source_to_graph[i] = cur;  // Reverse map: original → graph
+            cur++;
+        }
+    }
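+    // e.g. X = [0.5, 0.0, 0.5] gives n = 2, weights1 = [0.5, 0.5],
+    // indI = [0, 2] and source_to_graph = [0, -1, 1]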
+
+    cur = 0;
+    for (int i = 0; i < n2; i++) {
+        double val = *(Y + i);
+        if (val > 0) {
+            weights2[cur] = -val;      // Negated: network simplex expects demands
+            indJ[cur] = i;             // Forward map: graph → original
+            target_to_graph[i] = cur;  // Reverse map: original → graph
+            cur++;
+        }
+    }
+
+    typedef SparseBipartiteDigraph Digraph;
+    DIGRAPH_TYPEDEFS(Digraph);
+
+    Digraph di(n, m);
+
+    std::vector<std::pair<uint64_t, uint64_t>> edges;  // (source, target) pairs
+    std::vector<uint64_t> edge_to_arc;  // edge_to_arc[k] = arc ID for edge k
+    std::vector<double> arc_costs;      // arc_costs[arc_id] = cost (for O(1) lookup)
+    edges.reserve(n_edges);
+    edge_to_arc.reserve(n_edges);
+
+    uint64_t valid_edge_count = 0;
+    for (uint64_t k = 0; k < n_edges; k++) {
+        int64_t src_orig = edge_sources[k];
+        int64_t tgt_orig = edge_targets[k];
+        int64_t src = source_to_graph[src_orig];
+        int64_t tgt = target_to_graph[tgt_orig];
+
+        // Keep only edges whose endpoints both carry positive mass
+        if (src >= 0 && tgt >= 0) {
+            edges.emplace_back(src, tgt + n);
+            edge_to_arc.push_back(valid_edge_count);
+            arc_costs.push_back(edge_costs[k]);  // Store cost indexed by arc ID
+            valid_edge_count++;
+        } else {
+            edge_to_arc.push_back(UINT64_MAX);   // Mark edge as dropped
+        }
+    }
+
+    di.buildFromEdges(edges);
+
+    NetworkSimplexSimple<Digraph, double, double, node_id_type> net(
+        di, true, (int)(n + m), di.arcNum(), maxIter
+    );
+
+    net.supplyMap(&weights1[0], (int)n, &weights2[0], (int)m);
+
+    for (uint64_t k = 0; k < n_edges; k++) {
+        if (edge_to_arc[k] != UINT64_MAX) {
+            net.setCost(edge_to_arc[k], edge_costs[k]);
+        }
+    }
+
+    int ret = net.run();
+
+    if (ret == (int)net.OPTIMAL || ret == (int)net.MAX_ITER_REACHED) {
+        *cost = 0;
+        *n_flows_out = 0;
+
+        Arc a;
+        di.first(a);
+        for (; a != INVALID; di.next(a)) {
+            uint64_t i = di.source(a);
+            uint64_t j = di.target(a);
+            double flow = net.flow(a);
+
+            uint64_t orig_i = indI[i];
+            uint64_t orig_j = indJ[j - n];
+
+            double arc_cost = arc_costs[a];
+            *cost += flow * arc_cost;
+
+            *(alpha + orig_i) = -net.potential(i);
+            *(beta + orig_j) = net.potential(j);
+
+            if (flow > 1e-15) {
+                flow_sources_out[*n_flows_out] = orig_i;
+                flow_targets_out[*n_flows_out] = orig_j;
+                flow_values_out[*n_flows_out] = flow;
+                (*n_flows_out)++;
+            }
+        }
+    }
+    return ret;
+}
\ No newline at end of file
diff --git a/ot/lp/_network_simplex.py b/ot/lp/_network_simplex.py
index 492e4c7ac..8d033679c 100644
--- a/ot/lp/_network_simplex.py
+++ b/ot/lp/_network_simplex.py
@@ -11,9 +11,10 @@
 
 import numpy as np
 import warnings
+import scipy.sparse as sp
 
 from ..utils import list_to_array, check_number_threads
 from ..backend import get_backend
-from .emd_wrap import emd_c, check_result
+from .emd_wrap import emd_c, emd_c_sparse, check_result
 
 
 def center_ot_dual(alpha0, beta0, a=None, b=None):
@@ -233,15 +234,24 @@ def emd(
         If True, checks that the marginals mass are equal. If False, skips the
         check.
 
+    .. note:: The solver automatically detects the sparse format when M is provided
+        as a scipy.sparse matrix (coo, csr, csc, etc.).
+
+        For sparse inputs, the solver uses a memory-efficient algorithm and returns
+        the flow in edge format (via the log dict) instead of a full matrix.
+
     Returns
     -------
-    gamma: array-like, shape (ns, nt)
-        Optimal transportation matrix for the given
-        parameters
+    gamma: array-like, shape (ns, nt), or None
+        Optimal transportation matrix for the given parameters.
+        For sparse inputs, returns None (use log=True to get the flow in edge format).
log: dict, optional - If input log is true, a dictionary containing the - cost and dual variables and exit status + If input log is True, a dictionary containing the cost, dual variables, + and exit status. For sparse inputs with log=True, also contains: + - 'flow_sources': source nodes of flow edges + - 'flow_targets': target nodes of flow edges + - 'flow_values': flow values on edges Examples @@ -272,38 +282,72 @@ def emd( ot.optim.cg : General regularized OT """ - a, b, M = list_to_array(a, b, M) - nx = get_backend(M, a, b) + edge_sources = None + edge_targets = None + edge_costs = None + n1, n2 = None, None + + # Check for sparse format + is_sparse = sp.issparse(M) + + if is_sparse: + # Convert to COO format for edge extraction + if not isinstance(M, sp.coo_matrix): + M_coo = sp.coo_matrix(M) + else: + M_coo = M + + edge_sources = ( + M_coo.row if M_coo.row.dtype == np.uint64 else M_coo.row.astype(np.uint64) + ) + edge_targets = ( + M_coo.col if M_coo.col.dtype == np.uint64 else M_coo.col.astype(np.uint64) + ) + edge_costs = ( + M_coo.data + if M_coo.data.dtype == np.float64 + else M_coo.data.astype(np.float64) + ) + n1, n2 = M_coo.shape + + a, b = list_to_array(a, b) + elif isinstance(M, tuple): + raise ValueError( + "Tuple format for sparse cost matrix is not supported. " + "Please use scipy.sparse format (e.g., scipy.sparse.coo_matrix, csr_matrix, etc.)." + ) + else: + a, b, M = list_to_array(a, b, M) + + nx = get_backend(a, b) if len(a) != 0: type_as = a elif len(b) != 0: type_as = b else: - type_as = M + type_as = a + + # Set n1, n2 if not already set (dense case) + if n1 is None: + n1, n2 = M.shape - # if empty array given then use uniform distributions if len(a) == 0: - a = nx.ones((M.shape[0],), type_as=type_as) / M.shape[0] + a = nx.ones((n1,), type_as=type_as) / n1 if len(b) == 0: - b = nx.ones((M.shape[1],), type_as=type_as) / M.shape[1] + b = nx.ones((n2,), type_as=type_as) / n2 - # convert to numpy - M, a, b = nx.to_numpy(M, a, b) + if is_sparse: + a, b = nx.to_numpy(a, b) + else: + M, a, b = nx.to_numpy(M, a, b) + M = np.asarray(M, dtype=np.float64, order="C") - # ensure float64 a = np.asarray(a, dtype=np.float64) b = np.asarray(b, dtype=np.float64) - M = np.asarray(M, dtype=np.float64, order="C") - - # if empty array given then use uniform distributions - if len(a) == 0: - a = np.ones((M.shape[0],), dtype=np.float64) / M.shape[0] - if len(b) == 0: - b = np.ones((M.shape[1],), dtype=np.float64) / M.shape[1] assert ( - a.shape[0] == M.shape[0] and b.shape[0] == M.shape[1] + a.shape[0] == n1 and b.shape[0] == n2 ), "Dimension mismatch, check dimensions of M with a and b" # ensure that same mass @@ -321,32 +365,77 @@ def emd( numThreads = check_number_threads(numThreads) - G, cost, u, v, result_code = emd_c(a, b, M, numItermax, numThreads) + # ============================================================================ + # SPARSE SOLVER PATH + # ============================================================================ + if is_sparse: + # Sparse solver - never build full matrix + flow_sources, flow_targets, flow_values, cost, u, v, result_code = emd_c_sparse( + a, b, edge_sources, edge_targets, edge_costs, numItermax + ) - if center_dual: - u, v = center_ot_dual(u, v, a, b) + # Center dual potentials + if center_dual: + u, v = center_ot_dual(u, v, a, b) + + if np.any(~asel) or np.any(~bsel): + u, v = center_ot_dual(u, v, a, b) + + result_code_string = check_result(result_code) + + if log: + log_dict = {} + log_dict["cost"] = cost + log_dict["u"] = nx.from_numpy(u, 
type_as=type_as) + log_dict["v"] = nx.from_numpy(v, type_as=type_as) + log_dict["warning"] = result_code_string + log_dict["result_code"] = result_code + log_dict["flow_sources"] = flow_sources + log_dict["flow_targets"] = flow_targets + log_dict["flow_values"] = flow_values + + return None, log_dict + else: + raise ValueError( + "For sparse inputs, log=True is required to get the flow in edge format" + ) - if np.any(~asel) or np.any(~bsel): - u, v = estimate_dual_null_weights(u, v, a, b, M) + # ============================================================================ + # DENSE SOLVER PATH + # ============================================================================ + else: + # Dense solver + G, cost, u, v, result_code = emd_c(a, b, M, numItermax, numThreads) - result_code_string = check_result(result_code) - if not nx.is_floating_point(type_as): - warnings.warn( - "Input histogram consists of integer. The transport plan will be " - "casted accordingly, possibly resulting in a loss of precision. " - "If this behaviour is unwanted, please make sure your input " - "histogram consists of floating point elements.", - stacklevel=2, - ) - if log: - log = {} - log["cost"] = cost - log["u"] = nx.from_numpy(u, type_as=type_as) - log["v"] = nx.from_numpy(v, type_as=type_as) - log["warning"] = result_code_string - log["result_code"] = result_code - return nx.from_numpy(G, type_as=type_as), log - return nx.from_numpy(G, type_as=type_as) + # Center dual potentials + if center_dual: + u, v = center_ot_dual(u, v, a, b) + + if np.any(~asel) or np.any(~bsel): + u, v = estimate_dual_null_weights(u, v, a, b, M) + + result_code_string = check_result(result_code) + + if not nx.is_floating_point(type_as): + warnings.warn( + "Input histogram consists of integer. The transport plan will be " + "casted accordingly, possibly resulting in a loss of precision. " + "If this behaviour is unwanted, please make sure your input " + "histogram consists of floating point elements.", + stacklevel=2, + ) + + if log: + log_dict = {} + log_dict["cost"] = cost + log_dict["u"] = nx.from_numpy(u, type_as=type_as) + log_dict["v"] = nx.from_numpy(v, type_as=type_as) + log_dict["warning"] = result_code_string + log_dict["result_code"] = result_code + + return nx.from_numpy(G, type_as=type_as), log_dict + else: + return nx.from_numpy(G, type_as=type_as) def emd2( @@ -356,10 +445,10 @@ def emd2( processes=1, numItermax=100000, log=False, - return_matrix=False, center_dual=True, numThreads=1, check_marginals=True, + return_matrix=False, ): r"""Solves the Earth Movers distance problem and returns the loss @@ -421,6 +510,12 @@ def emd2( If True, checks that the marginals mass are equal. If False, skips the check. + .. note:: The solver automatically detects sparse format when M is provided as + a scipy.sparse matrix (coo, csr, csc, etc.). + + For sparse inputs, the solver uses a memory-efficient algorithm. + Edges not included are treated as having infinite cost (forbidden). 
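+
+        A minimal sketch of the sparse call (here ``rows``, ``cols``, ``costs``
+        and the shape are illustrative placeholders)::
+
+            >>> import scipy.sparse as sp
+            >>> M_sparse = sp.coo_matrix((costs, (rows, cols)), shape=(ns, nt))
+            >>> loss = ot.emd2(a, b, M_sparse)  # doctest: +SKIP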
+ Returns ------- @@ -460,34 +555,81 @@ def emd2( ot.optim.cg : General regularized OT """ - a, b, M = list_to_array(a, b, M) - nx = get_backend(M, a, b) + edge_sources = None + edge_targets = None + edge_costs = None + n1, n2 = None, None + + # Check for sparse format + is_sparse = sp.issparse(M) + + if is_sparse: + # Convert to COO format for edge extraction + if not isinstance(M, sp.coo_matrix): + M_coo = sp.coo_matrix(M) + else: + M_coo = M + + edge_sources = ( + M_coo.row if M_coo.row.dtype == np.uint64 else M_coo.row.astype(np.uint64) + ) + edge_targets = ( + M_coo.col if M_coo.col.dtype == np.uint64 else M_coo.col.astype(np.uint64) + ) + edge_costs = ( + M_coo.data + if M_coo.data.dtype == np.float64 + else M_coo.data.astype(np.float64) + ) + n1, n2 = M_coo.shape + + a, b = list_to_array(a, b) + elif isinstance(M, tuple): + raise ValueError( + "Tuple format for sparse cost matrix is not supported. " + "Please use scipy.sparse format (e.g., scipy.sparse.coo_matrix, csr_matrix, etc.)." + ) + else: + a, b, M = list_to_array(a, b, M) + + nx = get_backend(a, b) if len(a) != 0: type_as = a elif len(b) != 0: type_as = b else: - type_as = M + type_as = a # Can't use M for sparse case + + # Set n1, n2 if not already set (dense case) + if n1 is None: + n1, n2 = M.shape # if empty array given then use uniform distributions if len(a) == 0: - a = nx.ones((M.shape[0],), type_as=type_as) / M.shape[0] + a = nx.ones((n1,), type_as=type_as) / n1 if len(b) == 0: - b = nx.ones((M.shape[1],), type_as=type_as) / M.shape[1] + b = nx.ones((n2,), type_as=type_as) / n2 - # store original tensors - a0, b0, M0 = a, b, M + a0, b0 = a, b + M0 = None if is_sparse else M - # convert to numpy - M, a, b = nx.to_numpy(M, a, b) + if is_sparse: + edge_costs_original = nx.from_numpy(edge_costs, type_as=type_as) + else: + edge_costs_original = None + + if is_sparse: + a, b = nx.to_numpy(a, b) + else: + M, a, b = nx.to_numpy(M, a, b) + M = np.asarray(M, dtype=np.float64, order="C") a = np.asarray(a, dtype=np.float64) b = np.asarray(b, dtype=np.float64) - M = np.asarray(M, dtype=np.float64, order="C") assert ( - a.shape[0] == M.shape[0] and b.shape[0] == M.shape[1] + a.shape[0] == n1 and b.shape[0] == n2 ), "Dimension mismatch, check dimensions of M with a and b" # ensure that same mass @@ -504,48 +646,83 @@ def emd2( numThreads = check_number_threads(numThreads) - if log or return_matrix: + # ============================================================================ + # SPARSE SOLVER PATH + # ============================================================================ + if is_sparse: def f(b): bsel = b != 0 - G, cost, u, v, result_code = emd_c(a, b, M, numItermax, numThreads) + # Solve sparse EMD + flow_sources, flow_targets, flow_values, cost, u, v, result_code = ( + emd_c_sparse(a, b, edge_sources, edge_targets, edge_costs, numItermax) + ) + # Build gradient mapping for edge costs + edge_to_idx = { + (edge_sources[k], edge_targets[k]): k for k in range(len(edge_sources)) + } + + grad_edge_costs = np.zeros(len(edge_costs), dtype=np.float64) + for idx in range(len(flow_sources)): + src, tgt, flow = ( + flow_sources[idx], + flow_targets[idx], + flow_values[idx], + ) + edge_idx = edge_to_idx.get((src, tgt), -1) + if edge_idx >= 0: + grad_edge_costs[edge_idx] = flow + + # Center dual potentials if center_dual: u, v = center_ot_dual(u, v, a, b) if np.any(~asel) or np.any(~bsel): - u, v = estimate_dual_null_weights(u, v, a, b, M) + u, v = center_ot_dual(u, v, a, b) - result_code_string = check_result(result_code) - log = {} 
- if not nx.is_floating_point(type_as): - warnings.warn( - "Input histogram consists of integer. The transport plan will be " - "casted accordingly, possibly resulting in a loss of precision. " - "If this behaviour is unwanted, please make sure your input " - "histogram consists of floating point elements.", - stacklevel=2, - ) - G = nx.from_numpy(G, type_as=type_as) - if return_matrix: - log["G"] = G - log["u"] = nx.from_numpy(u, type_as=type_as) - log["v"] = nx.from_numpy(v, type_as=type_as) - log["warning"] = result_code_string - log["result_code"] = result_code + # Prepare cost with gradients cost = nx.set_gradients( nx.from_numpy(cost, type_as=type_as), - (a0, b0, M0), - (log["u"] - nx.mean(log["u"]), log["v"] - nx.mean(log["v"]), G), + (a0, b0, edge_costs_original), + ( + nx.from_numpy(u - np.mean(u), type_as=type_as), + nx.from_numpy(v - np.mean(v), type_as=type_as), + nx.from_numpy(grad_edge_costs, type_as=type_as), + ), ) - return [cost, log] + + check_result(result_code) + + if log or return_matrix: + log_dict = {} + log_dict["u"] = nx.from_numpy(u, type_as=type_as) + log_dict["v"] = nx.from_numpy(v, type_as=type_as) + log_dict["warning"] = check_result(result_code) + log_dict["result_code"] = result_code + + if return_matrix: + G = np.zeros((len(a), len(b)), dtype=np.float64) + G[flow_sources, flow_targets] = flow_values + log_dict["G"] = nx.from_numpy(G, type_as=type_as) + + return [cost, log_dict] + else: + return cost + + # ============================================================================ + # DENSE SOLVER PATH + # ============================================================================ else: def f(b): bsel = b != 0 + + # Solve dense EMD G, cost, u, v, result_code = emd_c(a, b, M, numItermax, numThreads) + # Center dual potentials if center_dual: u, v = center_ot_dual(u, v, a, b) @@ -560,6 +737,7 @@ def f(b): "histogram consists of floating point elements.", stacklevel=2, ) + G = nx.from_numpy(G, type_as=type_as) cost = nx.set_gradients( nx.from_numpy(cost, type_as=type_as), @@ -572,7 +750,20 @@ def f(b): ) check_result(result_code) - return cost + + if log or return_matrix: + log_dict = {} + log_dict["u"] = nx.from_numpy(u, type_as=type_as) + log_dict["v"] = nx.from_numpy(v, type_as=type_as) + log_dict["warning"] = check_result(result_code) + log_dict["result_code"] = result_code + + if return_matrix: + log_dict["G"] = G + + return [cost, log_dict] + else: + return cost if len(b.shape) == 1: return f(b) diff --git a/ot/lp/emd_wrap.pyx b/ot/lp/emd_wrap.pyx index 53df54fc3..3b19d3fdd 100644 --- a/ot/lp/emd_wrap.pyx +++ b/ot/lp/emd_wrap.pyx @@ -14,7 +14,7 @@ from ..utils import dist cimport cython cimport libc.math as math -from libc.stdint cimport uint64_t +from libc.stdint cimport uint64_t, int64_t import warnings @@ -22,6 +22,7 @@ import warnings cdef extern from "EMD.h": int EMD_wrap(int n1,int n2, double *X, double *Y,double *D, double *G, double* alpha, double* beta, double *cost, uint64_t maxIter) nogil int EMD_wrap_omp(int n1,int n2, double *X, double *Y,double *D, double *G, double* alpha, double* beta, double *cost, uint64_t maxIter, int numThreads) nogil + int EMD_wrap_sparse(int n1, int n2, double *X, double *Y, uint64_t n_edges, uint64_t *edge_sources, uint64_t *edge_targets, double *edge_costs, uint64_t *flow_sources_out, uint64_t *flow_targets_out, double *flow_values_out, uint64_t *n_flows_out, double *alpha, double *beta, double *cost, uint64_t maxIter) nogil cdef enum ProblemType: INFEASIBLE, OPTIMAL, UNBOUNDED, MAX_ITER_REACHED @@ -206,3 +207,78 
@@ def emd_1d_sorted(np.ndarray[double, ndim=1, mode="c"] u_weights,
             cur_idx += 1
     cur_idx += 1
     return G[:cur_idx], indices[:cur_idx], cost
+
+@cython.boundscheck(False)
+@cython.wraparound(False)
+def emd_c_sparse(np.ndarray[double, ndim=1, mode="c"] a,
+                 np.ndarray[double, ndim=1, mode="c"] b,
+                 np.ndarray[uint64_t, ndim=1, mode="c"] edge_sources,
+                 np.ndarray[uint64_t, ndim=1, mode="c"] edge_targets,
+                 np.ndarray[double, ndim=1, mode="c"] edge_costs,
+                 uint64_t max_iter):
+    """
+    Sparse EMD solver - only considers the edges listed in edge_sources/edge_targets
+
+    Parameters
+    ----------
+    a : (n1,) array
+        Source histogram
+    b : (n2,) array
+        Target histogram
+    edge_sources : (k,) array, uint64
+        Source indices for each edge
+    edge_targets : (k,) array, uint64
+        Target indices for each edge
+    edge_costs : (k,) array, float64
+        Cost for each edge
+    max_iter : uint64_t
+        Maximum iterations
+
+    Returns
+    -------
+    flow_sources : (n_flows,) array, uint64
+        Source indices of non-zero flows
+    flow_targets : (n_flows,) array, uint64
+        Target indices of non-zero flows
+    flow_values : (n_flows,) array, float64
+        Flow values
+    cost : float
+        Total cost
+    alpha : (n1,) array
+        Dual variables for sources
+    beta : (n2,) array
+        Dual variables for targets
+    result_code : int
+        Result status
+    """
+    cdef int n1 = a.shape[0]
+    cdef int n2 = b.shape[0]
+    cdef uint64_t n_edges = edge_sources.shape[0]
+    cdef uint64_t n_flows_out = 0
+    cdef int result_code = 0
+    cdef double cost = 0
+
+    # Allocate output arrays (max size = n_edges)
+    cdef np.ndarray[uint64_t, ndim=1, mode="c"] flow_sources = np.zeros(n_edges, dtype=np.uint64)
+    cdef np.ndarray[uint64_t, ndim=1, mode="c"] flow_targets = np.zeros(n_edges, dtype=np.uint64)
+    cdef np.ndarray[double, ndim=1, mode="c"] flow_values = np.zeros(n_edges, dtype=np.float64)
+    cdef np.ndarray[double, ndim=1, mode="c"] alpha = np.zeros(n1)
+    cdef np.ndarray[double, ndim=1, mode="c"] beta = np.zeros(n2)
+
+    with nogil:
+        result_code = EMD_wrap_sparse(
+            n1, n2,
+            <double*> a.data, <double*> b.data,
+            n_edges,
+            <uint64_t*> edge_sources.data, <uint64_t*> edge_targets.data,
+            <double*> edge_costs.data,
+            <uint64_t*> flow_sources.data, <uint64_t*> flow_targets.data,
+            <double*> flow_values.data,
+            &n_flows_out,
+            <double*> alpha.data, <double*> beta.data, &cost, max_iter
+        )
+
+    # Trim to actual number of flows
+    flow_sources = flow_sources[:n_flows_out]
+    flow_targets = flow_targets[:n_flows_out]
+    flow_values = flow_values[:n_flows_out]
+
+    return flow_sources, flow_targets, flow_values, cost, alpha, beta, result_code
\ No newline at end of file
diff --git a/ot/lp/sparse_bipartitegraph.h b/ot/lp/sparse_bipartitegraph.h
new file mode 100644
index 000000000..7ba13b41a
--- /dev/null
+++ b/ot/lp/sparse_bipartitegraph.h
@@ -0,0 +1,281 @@
+/* -*- mode: C++; indent-tabs-mode: nil; -*-
+ *
+ * Sparse bipartite graph for optimal transport
+ * Only stores edges that are explicitly added (not all n1×n2 edges)
+ *
+ * Uses CSR (Compressed Sparse Row) format for better cache locality and performance
+ * - Binary search for arc lookup: O(log k) where k = avg edges per node
+ * - Compact memory layout for better cache performance
+ * - Edges are sorted by (source, target) during construction
+ */
+
+#pragma once
+
+#include "core.h"
+#include <vector>
+#include <algorithm>
+#include <cstdint>
+
+namespace lemon {
+
+  class SparseBipartiteDigraphBase {
+  public:
+
+    typedef SparseBipartiteDigraphBase Digraph;
+    typedef int Node;
+    typedef int64_t Arc;
+
+  protected:
+
+    int _node_num;
+    int64_t _arc_num;
+    int _n1, _n2;
+
+    std::vector<Node> _arc_sources;  // _arc_sources[arc_id] = source node
+    std::vector<Node> _arc_targets;  // _arc_targets[arc_id] = target node
+
+    // CSR format:
+    // _row_ptr[i] = start index in _col_indices for source node i
+    // _row_ptr[i+1] - _row_ptr[i] = number of outgoing edges from node i
+    std::vector<int64_t> _row_ptr;
+    std::vector<Node> _col_indices;  // Target node stored at each CSR slot
+    std::vector<Arc> _arc_ids;       // Original arc id stored at each CSR slot
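+    //
+    // Example: with _n1 = 2 and edges {(0,2), (0,3), (1,2)} (targets are
+    // global node ids >= _n1), this yields _row_ptr = [0, 2, 3],
+    // _col_indices = [2, 3, 2] and _arc_ids = [0, 1, 2].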
+
+    mutable std::vector<std::vector<Arc>> _in_arcs;  // _in_arcs[node] = incoming arc IDs
+    mutable bool _in_arcs_built;
+
+    SparseBipartiteDigraphBase() : _node_num(0), _arc_num(0), _n1(0), _n2(0), _in_arcs_built(false) {}
+
+    void construct(int n1, int n2) {
+      _node_num = n1 + n2;
+      _n1 = n1;
+      _n2 = n2;
+      _arc_num = 0;
+      _arc_sources.clear();
+      _arc_targets.clear();
+      _row_ptr.clear();
+      _col_indices.clear();
+      _arc_ids.clear();
+      _in_arcs.clear();
+      _in_arcs_built = false;
+    }
+
+    void build_in_arcs() const {
+      if (_in_arcs_built) return;
+
+      _in_arcs.resize(_node_num);
+
+      for (Arc a = 0; a < _arc_num; ++a) {
+        Node tgt = _arc_targets[a];
+        _in_arcs[tgt].push_back(a);
+      }
+
+      _in_arcs_built = true;
+    }
+
+  public:
+
+    Node operator()(int ix) const { return Node(ix); }
+    static int index(const Node& node) { return node; }
+
+    void buildFromEdges(const std::vector<std::pair<uint64_t, uint64_t>>& edges) {
+      _arc_num = edges.size();
+
+      if (_arc_num == 0) {
+        _row_ptr.assign(_n1 + 1, 0);
+        return;
+      }
+
+      // Create indexed edges: (source, target, original_arc_id)
+      std::vector<std::tuple<Node, Node, Arc>> indexed_edges;
+      indexed_edges.reserve(_arc_num);
+      for (Arc i = 0; i < _arc_num; ++i) {
+        indexed_edges.emplace_back(edges[i].first, edges[i].second, i);
+      }
+
+      // Sort by source node, then by target node (CSR requirement)
+      std::sort(indexed_edges.begin(), indexed_edges.end(),
+                [](const auto& a, const auto& b) {
+                  if (std::get<0>(a) != std::get<0>(b))
+                    return std::get<0>(a) < std::get<0>(b);
+                  return std::get<1>(a) < std::get<1>(b);
+                });
+
+      _arc_sources.resize(_arc_num);
+      _arc_targets.resize(_arc_num);
+      _col_indices.resize(_arc_num);
+      _arc_ids.resize(_arc_num);
+      _row_ptr.resize(_n1 + 1);
+
+      _row_ptr[0] = 0;
+      int current_row = 0;
+
+      for (int64_t i = 0; i < _arc_num; ++i) {
+        Node src = std::get<0>(indexed_edges[i]);
+        Node tgt = std::get<1>(indexed_edges[i]);
+        Arc orig_arc_id = std::get<2>(indexed_edges[i]);
+
+        // Fill out row_ptr for rows with no outgoing edges
+        while (current_row < src) {
+          _row_ptr[++current_row] = i;
+        }
+
+        _arc_sources[orig_arc_id] = src;
+        _arc_targets[orig_arc_id] = tgt;
+        _col_indices[i] = tgt;
+        _arc_ids[i] = orig_arc_id;
+      }
+
+      // Fill remaining row_ptr entries
+      while (current_row < _n1) {
+        _row_ptr[++current_row] = _arc_num;
+      }
+
+      _in_arcs_built = false;
+    }
+
+    // Find the arc from s to t using binary search (returns -1 if not found)
+    Arc arc(const Node& s, const Node& t) const {
+      if (s < 0 || s >= _n1 || t < _n1 || t >= _node_num) {
+        return Arc(-1);
+      }
+
+      int64_t start = _row_ptr[s];
+      int64_t end = _row_ptr[s + 1];
+
+      // Binary search for target t in _col_indices[start:end]
+      auto it = std::lower_bound(
+        _col_indices.begin() + start,
+        _col_indices.begin() + end,
+        t
+      );
+
+      if (it != _col_indices.begin() + end && *it == t) {
+        int64_t pos = it - _col_indices.begin();
+        return _arc_ids[pos];
+      }
+
+      return Arc(-1);
+    }
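+    // With the example layout above, arc(0, 3) searches
+    // _col_indices[0:2] = {2, 3}, finds t = 3 at CSR position 1 and returns
+    // _arc_ids[1] = 1, while arc(1, 3) searches _col_indices[2:3] = {2} and
+    // returns -1.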
+
+    int nodeNum() const { return _node_num; }
+    int64_t arcNum() const { return _arc_num; }
+
+    int maxNodeId() const { return _node_num - 1; }
+    int64_t maxArcId() const { return _arc_num - 1; }
+
+    Node source(Arc arc) const {
+      return (arc >= 0 && arc < _arc_num) ? _arc_sources[arc] : Node(-1);
+    }
+
+    Node target(Arc arc) const {
+      return (arc >= 0 && arc < _arc_num) ? _arc_targets[arc] : Node(-1);
+    }
+
+    static int id(Node node) { return node; }
+    static int64_t id(Arc arc) { return arc; }
+
+    static Node nodeFromId(int id) { return Node(id); }
+    static Arc arcFromId(int64_t id) { return Arc(id); }
+
+    Arc findArc(Node s, Node t, Arc prev = -1) const {
+      return prev == -1 ? arc(s, t) : Arc(-1);
+    }
+
+    void first(Node& node) const {
+      node = _node_num - 1;
+    }
+
+    static void next(Node& node) {
+      --node;
+    }
+
+    void first(Arc& arc) const {
+      arc = _arc_num - 1;
+    }
+
+    static void next(Arc& arc) {
+      --arc;
+    }
+
+    void firstOut(Arc& arc, const Node& node) const {
+      if (node < 0 || node >= _n1) {
+        arc = -1;
+        return;
+      }
+
+      int64_t start = _row_ptr[node];
+      int64_t end = _row_ptr[node + 1];
+
+      arc = (start < end) ? _arc_ids[start] : Arc(-1);
+    }
+
+    void nextOut(Arc& arc) const {
+      if (arc < 0) return;
+
+      Node src = _arc_sources[arc];
+      int64_t start = _row_ptr[src];
+      int64_t end = _row_ptr[src + 1];
+
+      // Locate the current arc in its CSR row and advance to the next slot
+      for (int64_t i = start; i < end; ++i) {
+        if (_arc_ids[i] == arc) {
+          arc = (i + 1 < end) ? _arc_ids[i + 1] : Arc(-1);
+          return;
+        }
+      }
+      arc = -1;
+    }
+
+    void firstIn(Arc& arc, const Node& node) const {
+      build_in_arcs();  // Lazy build on first call
+
+      if (node < 0 || node >= _node_num || node < _n1) {
+        arc = -1;  // Invalid node, or a source node (sources have no incoming arcs)
+        return;
+      }
+
+      const std::vector<Arc>& in = _in_arcs[node];
+      arc = in.empty() ? Arc(-1) : in[0];
+    }
+
+    void nextIn(Arc& arc) const {
+      if (arc < 0) return;
+
+      Node tgt = _arc_targets[arc];
+      const std::vector<Arc>& in = _in_arcs[tgt];
+
+      // Find the current arc in the list and return the next one
+      for (size_t i = 0; i < in.size(); ++i) {
+        if (in[i] == arc) {
+          arc = (i + 1 < in.size()) ? in[i + 1] : Arc(-1);
+          return;
+        }
+      }
+      arc = -1;
+    }
+  };
+
+  /// Sparse bipartite digraph - only stores edges that are explicitly added
+  class SparseBipartiteDigraph : public SparseBipartiteDigraphBase {
+    typedef SparseBipartiteDigraphBase Parent;
+
+  public:
+
+    SparseBipartiteDigraph() { construct(0, 0); }
+
+    SparseBipartiteDigraph(int n1, int n2) { construct(n1, n2); }
+
+    Node operator()(int ix) const { return Parent::operator()(ix); }
+    static int index(const Node& node) { return Parent::index(node); }
+
+    void buildFromEdges(const std::vector<std::pair<uint64_t, uint64_t>>& edges) {
+      Parent::buildFromEdges(edges);
+    }
+
+    Arc arc(Node s, Node t) const { return Parent::arc(s, t); }
+
+    int nodeNum() const { return Parent::nodeNum(); }
+    int64_t arcNum() const { return Parent::arcNum(); }
+  };
+
+} //namespace lemon
diff --git a/setup.py b/setup.py
index acbe5aed9..c8cefb729 100644
--- a/setup.py
+++ b/setup.py
@@ -50,7 +50,19 @@
     link_args += flags
 
 if sys.platform.startswith("darwin"):
-    compile_args.append("-stdlib=libc++")
+    # Only add -stdlib=libc++ for Clang, not GCC:
+    # GCC uses libstdc++ by default and doesn't recognize the -stdlib flag
+    import subprocess
+    try:
+        # Check if using clang
+        compiler = os.environ.get('CXX', 'c++')
+        version_output = subprocess.check_output([compiler, '--version'], stderr=subprocess.STDOUT).decode()
+        if 'clang' in version_output.lower():
+            compile_args.append("-stdlib=libc++")
+    except Exception:
+        # If we can't determine the compiler, don't add the flag (safer for GCC)
+        pass
+
     sdk_path = subprocess.check_output(["xcrun", "--show-sdk-path"])
     os.environ["CFLAGS"] = '-isysroot "{}"'.format(sdk_path.rstrip().decode("utf-8"))
diff --git a/test/test_ot.py b/test/test_ot.py
index e8217d54d..691f77382 100644
--- a/test/test_ot.py
+++ b/test/test_ot.py
@@ -12,6 +12,7 @@
 import ot
 from ot.datasets import
make_1D_gauss as gauss from ot.backend import torch, tf, get_backend +from scipy.sparse import coo_matrix def test_emd_dimension_and_mass_mismatch(): @@ -914,6 +915,333 @@ def test_dual_variables(): assert constraint_violation.max() < 1e-8 +def test_emd_sparse_vs_dense(): + """Test that sparse and dense EMD solvers produce identical results. + + Uses augmented k-NN graph approach: first solves with dense solver to + identify needed edges, then compares both solvers on the same graph. + """ + n_source = 100 + n_target = 100 + k = 10 + + rng = np.random.RandomState(42) + + x_source = rng.randn(n_source, 2) + x_target = rng.randn(n_target, 2) + 0.5 + + a = ot.utils.unif(n_source) + b = ot.utils.unif(n_target) + + C = ot.dist(x_source, x_target) + + rows = [] + cols = [] + data = [] + + for i in range(n_source): + distances = C[i, :] + nearest_k = np.argpartition(distances, k)[:k] + for j in nearest_k: + rows.append(i) + cols.append(j) + data.append(C[i, j]) + + C_knn = coo_matrix((data, (rows, cols)), shape=(n_source, n_target)) + + large_cost = 1e8 + C_dense_infty = np.full((n_source, n_target), large_cost) + C_knn_array = C_knn.toarray() + C_dense_infty[C_knn_array > 0] = C_knn_array[C_knn_array > 0] + + G_dense_initial = ot.emd(a, b, C_dense_infty) + eps = 1e-9 + active_mask = G_dense_initial > eps + knn_mask = C_knn_array > 0 + extra_edges_mask = active_mask & ~knn_mask + + rows_aug = [] + cols_aug = [] + data_aug = [] + + knn_rows, knn_cols = np.where(knn_mask) + for i, j in zip(knn_rows, knn_cols): + rows_aug.append(i) + cols_aug.append(j) + data_aug.append(C[i, j]) + + extra_rows, extra_cols = np.where(extra_edges_mask) + for i, j in zip(extra_rows, extra_cols): + rows_aug.append(i) + cols_aug.append(j) + data_aug.append(C[i, j]) + + C_augmented = coo_matrix( + (data_aug, (rows_aug, cols_aug)), shape=(n_source, n_target) + ) + + C_augmented_dense = np.full((n_source, n_target), large_cost) + C_augmented_array = C_augmented.toarray() + C_augmented_dense[C_augmented_array > 0] = C_augmented_array[C_augmented_array > 0] + + G_dense, log_dense = ot.emd(a, b, C_augmented_dense, log=True) + G_sparse, log_sparse = ot.emd(a, b, C_augmented, log=True) + + cost_dense = log_dense["cost"] + cost_sparse = log_sparse["cost"] + + np.testing.assert_allclose(cost_dense, cost_sparse, rtol=1e-5, atol=1e-7) + + # For dense, G_dense is returned; for sparse, reconstruct from flow edges + np.testing.assert_allclose(a, G_dense.sum(1), rtol=1e-5, atol=1e-7) + np.testing.assert_allclose(b, G_dense.sum(0), rtol=1e-5, atol=1e-7) + + # Reconstruct sparse matrix from flow for marginal checks + if G_sparse is None: + G_sparse_reconstructed = np.zeros((n_source, n_target)) + G_sparse_reconstructed[ + log_sparse["flow_sources"], log_sparse["flow_targets"] + ] = log_sparse["flow_values"] + np.testing.assert_allclose( + a, G_sparse_reconstructed.sum(1), rtol=1e-5, atol=1e-7 + ) + np.testing.assert_allclose( + b, G_sparse_reconstructed.sum(0), rtol=1e-5, atol=1e-7 + ) + else: + np.testing.assert_allclose(a, G_sparse.sum(1), rtol=1e-5, atol=1e-7) + np.testing.assert_allclose(b, G_sparse.sum(0), rtol=1e-5, atol=1e-7) + + +def test_emd2_sparse_vs_dense(): + """Test that sparse and dense emd2 solvers produce identical results. + + Uses augmented k-NN graph approach: first solves with dense solver to + identify needed edges, then compares both solvers on the same graph. 
+ """ + n_source = 100 + n_target = 100 + k = 10 + + rng = np.random.RandomState(42) + + x_source = rng.randn(n_source, 2) + x_target = rng.randn(n_target, 2) + 0.5 + + a = ot.utils.unif(n_source) + b = ot.utils.unif(n_target) + + C = ot.dist(x_source, x_target) + + rows = [] + cols = [] + data = [] + + for i in range(n_source): + distances = C[i, :] + nearest_k = np.argpartition(distances, k)[:k] + for j in nearest_k: + rows.append(i) + cols.append(j) + data.append(C[i, j]) + + C_knn = coo_matrix((data, (rows, cols)), shape=(n_source, n_target)) + + large_cost = 1e8 + C_dense_infty = np.full((n_source, n_target), large_cost) + C_knn_array = C_knn.toarray() + C_dense_infty[C_knn_array > 0] = C_knn_array[C_knn_array > 0] + + G_dense_initial = ot.emd(a, b, C_dense_infty) + + eps = 1e-9 + active_mask = G_dense_initial > eps + knn_mask = C_knn_array > 0 + extra_edges_mask = active_mask & ~knn_mask + + rows_aug = [] + cols_aug = [] + data_aug = [] + + knn_rows, knn_cols = np.where(knn_mask) + for i, j in zip(knn_rows, knn_cols): + rows_aug.append(i) + cols_aug.append(j) + data_aug.append(C[i, j]) + + extra_rows, extra_cols = np.where(extra_edges_mask) + for i, j in zip(extra_rows, extra_cols): + rows_aug.append(i) + cols_aug.append(j) + data_aug.append(C[i, j]) + + C_augmented = coo_matrix( + (data_aug, (rows_aug, cols_aug)), shape=(n_source, n_target) + ) + + C_augmented_dense = np.full((n_source, n_target), large_cost) + C_augmented_array = C_augmented.toarray() + C_augmented_dense[C_augmented_array > 0] = C_augmented_array[C_augmented_array > 0] + + cost_dense = ot.emd2(a, b, C_augmented_dense) + cost_sparse = ot.emd2(a, b, C_augmented) + + np.testing.assert_allclose(cost_dense, cost_sparse, rtol=1e-5, atol=1e-7) + + +def test_emd_sparse_backends(nx): + """Test that sparse EMD works with different backends for weights a and b. + + Uses augmented k-NN graph approach to ensure feasibility. 
+ """ + n_source = 50 + n_target = 50 + k = 10 + + rng = np.random.RandomState(42) + + a = ot.utils.unif(n_source) + b = ot.utils.unif(n_target) + + x_source = rng.randn(n_source, 2) + x_target = rng.randn(n_target, 2) + 0.5 + C = ot.dist(x_source, x_target) + + rows = [] + cols = [] + data = [] + + for i in range(n_source): + distances = C[i, :] + nearest_k = np.argpartition(distances, k)[:k] + for j in nearest_k: + rows.append(i) + cols.append(j) + data.append(C[i, j]) + + C_knn = coo_matrix((data, (rows, cols)), shape=(n_source, n_target)) + + # Augment with necessary edges (same approach as test_emd_sparse_vs_dense) + large_cost = 1e8 + C_dense_infty = np.full((n_source, n_target), large_cost) + C_knn_array = C_knn.toarray() + C_dense_infty[C_knn_array > 0] = C_knn_array[C_knn_array > 0] + + G_dense_initial = ot.emd(a, b, C_dense_infty) + eps = 1e-9 + active_mask = G_dense_initial > eps + knn_mask = C_knn_array > 0 + extra_edges_mask = active_mask & ~knn_mask + + rows_aug = [] + cols_aug = [] + data_aug = [] + + knn_rows, knn_cols = np.where(knn_mask) + for i, j in zip(knn_rows, knn_cols): + rows_aug.append(i) + cols_aug.append(j) + data_aug.append(C[i, j]) + + extra_rows, extra_cols = np.where(extra_edges_mask) + for i, j in zip(extra_rows, extra_cols): + rows_aug.append(i) + cols_aug.append(j) + data_aug.append(C[i, j]) + + C_augmented = coo_matrix( + (data_aug, (rows_aug, cols_aug)), shape=(n_source, n_target) + ) + + _, log_np = ot.emd(a, b, C_augmented, log=True) + + ab, bb = nx.from_numpy(a, b) + _, log_backend = ot.emd(ab, bb, C_augmented, log=True) + + cost_np = log_np["cost"] + cost_backend = nx.to_numpy(log_backend["cost"]) + + np.testing.assert_allclose(cost_np, cost_backend, rtol=1e-5, atol=1e-7) + + np.testing.assert_allclose( + log_np["flow_values"], log_backend["flow_values"], rtol=1e-5, atol=1e-7 + ) + + +def test_emd2_sparse_backends(nx): + """Test that sparse emd2 works with different backends for weights a and b. + + Uses augmented k-NN graph approach to ensure feasibility. 
+ """ + n_source = 50 + n_target = 50 + k = 10 + + rng = np.random.RandomState(42) + + a = ot.utils.unif(n_source) + b = ot.utils.unif(n_target) + + x_source = rng.randn(n_source, 2) + x_target = rng.randn(n_target, 2) + 0.5 + C = ot.dist(x_source, x_target) + + rows = [] + cols = [] + data = [] + + for i in range(n_source): + distances = C[i, :] + nearest_k = np.argpartition(distances, k)[:k] + for j in nearest_k: + rows.append(i) + cols.append(j) + data.append(C[i, j]) + + C_knn = coo_matrix((data, (rows, cols)), shape=(n_source, n_target)) + + # Augment with necessary edges (same approach as test_emd2_sparse_vs_dense) + large_cost = 1e8 + C_dense_infty = np.full((n_source, n_target), large_cost) + C_knn_array = C_knn.toarray() + C_dense_infty[C_knn_array > 0] = C_knn_array[C_knn_array > 0] + + G_dense_initial = ot.emd(a, b, C_dense_infty) + eps = 1e-9 + active_mask = G_dense_initial > eps + knn_mask = C_knn_array > 0 + extra_edges_mask = active_mask & ~knn_mask + + rows_aug = [] + cols_aug = [] + data_aug = [] + + knn_rows, knn_cols = np.where(knn_mask) + for i, j in zip(knn_rows, knn_cols): + rows_aug.append(i) + cols_aug.append(j) + data_aug.append(C[i, j]) + + extra_rows, extra_cols = np.where(extra_edges_mask) + for i, j in zip(extra_rows, extra_cols): + rows_aug.append(i) + cols_aug.append(j) + data_aug.append(C[i, j]) + + C_augmented = coo_matrix( + (data_aug, (rows_aug, cols_aug)), shape=(n_source, n_target) + ) + + cost_np = ot.emd2(a, b, C_augmented) + + ab, bb = nx.from_numpy(a, b) + cost_backend = ot.emd2(ab, bb, C_augmented) + + cost_backend_np = nx.to_numpy(cost_backend) + + np.testing.assert_allclose(cost_np, cost_backend_np, rtol=1e-5, atol=1e-7) + + def check_duality_gap(a, b, M, G, u, v, cost): cost_dual = np.vdot(a, u) + np.vdot(b, v) # Check that dual and primal cost are equal