Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions MANIFEST.in
Original file line number Diff line number Diff line change
Expand Up @@ -10,4 +10,5 @@ include ot/lp/full_bipartitegraph.h
include ot/lp/full_bipartitegraph_omp.h
include ot/lp/network_simplex_simple.h
include ot/lp/network_simplex_simple_omp.h
include ot/lp/sparse_bipartitegraph.h
include ot/partial/partial_cython.pyx
18 changes: 18 additions & 0 deletions ot/lp/EMD.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,24 @@ enum ProblemType {
int EMD_wrap(int n1,int n2, double *X, double *Y,double *D, double *G, double* alpha, double* beta, double *cost, uint64_t maxIter);
int EMD_wrap_omp(int n1,int n2, double *X, double *Y,double *D, double *G, double* alpha, double* beta, double *cost, uint64_t maxIter, int numThreads);

// Sparse variant of the EMD solver: the transport graph is given as an edge
// list (edge_sources[k] -> edge_targets[k] with cost edge_costs[k]) instead of
// a dense cost matrix. Only entries of X and Y that are strictly positive take
// part in the problem; an edge touching a zero-mass entry is ignored.
//
// Returns INFEASIBLE immediately if any entry of X or Y is negative; otherwise
// returns the status code produced by the network simplex run. The flow, dual
// and cost outputs are filled in only when that status is OPTIMAL or
// MAX_ITER_REACHED.
int EMD_wrap_sparse(
int n1, // Number of entries in X (source side)
int n2, // Number of entries in Y (target side)
double *X, // Source masses (n1); must be non-negative
double *Y, // Target masses (n2); must be non-negative
uint64_t n_edges, // Number of edges in sparse graph
uint64_t *edge_sources, // Source indices for each edge (n_edges), indexing into X
uint64_t *edge_targets, // Target indices for each edge (n_edges), indexing into Y
double *edge_costs, // Cost for each edge (n_edges)
uint64_t *flow_sources_out, // Output: source indices of non-zero flows
uint64_t *flow_targets_out, // Output: target indices of non-zero flows
double *flow_values_out, // Output: flow values
uint64_t *n_flows_out, // Output: number of entries written to the three flow_*_out arrays
double *alpha, // Output: dual variables for sources (n1)
double *beta, // Output: dual variables for targets (n2)
double *cost, // Output: total transportation cost
uint64_t maxIter // Maximum iterations for solver
);


#endif
155 changes: 155 additions & 0 deletions ot/lp/EMD_wrapper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,10 @@

#include "network_simplex_simple.h"
#include "network_simplex_simple_omp.h"
#include "sparse_bipartitegraph.h"
#include "EMD.h"
#include <cstdint>
#include <unordered_map>


int EMD_wrap(int n1, int n2, double *X, double *Y, double *D, double *G,
Expand Down Expand Up @@ -216,3 +218,156 @@ int EMD_wrap_omp(int n1, int n2, double *X, double *Y, double *D, double *G,

return ret;
}

// ============================================================================
// SPARSE VERSION: Accepts edge list instead of dense cost matrix
// ============================================================================
/**
 * Solve an earth mover's distance problem on a sparse bipartite graph.
 *
 * The graph is described by an edge list rather than a dense cost matrix.
 * Only sources/targets with strictly positive mass become graph nodes; edges
 * that touch a zero-mass node, or whose endpoint index lies outside
 * [0, n1) / [0, n2), are ignored.
 *
 * @param n1               Number of entries in X.
 * @param n2               Number of entries in Y.
 * @param X                Source masses (n1 entries, must be non-negative).
 * @param Y                Target masses (n2 entries, must be non-negative).
 * @param n_edges          Number of entries in the three edge_* arrays.
 * @param edge_sources     Source index (into X) of each edge.
 * @param edge_targets     Target index (into Y) of each edge.
 * @param edge_costs       Cost of each edge.
 * @param flow_sources_out Output: original source index of each reported flow.
 * @param flow_targets_out Output: original target index of each reported flow.
 * @param flow_values_out  Output: value of each reported flow.
 *                         NOTE(review): the three flow_*_out buffers receive at
 *                         most one entry per arc, so n_edges entries should be
 *                         sufficient -- confirm against the caller.
 * @param n_flows_out      Output: number of flows written (0 on failure).
 * @param alpha            Output: source duals, written at original source
 *                         indices that are the tail of at least one arc.
 * @param beta             Output: target duals, written at original target
 *                         indices that are the head of at least one arc.
 * @param cost             Output: sum of flow * cost over all arcs (0 on failure).
 * @param maxIter          Iteration limit forwarded to the network simplex.
 * @return INFEASIBLE if any mass is negative, otherwise the solver status.
 */
int EMD_wrap_sparse(
    int n1,
    int n2,
    double *X,
    double *Y,
    uint64_t n_edges,
    uint64_t *edge_sources,
    uint64_t *edge_targets,
    double *edge_costs,
    uint64_t *flow_sources_out,
    uint64_t *flow_targets_out,
    double *flow_values_out,
    uint64_t *n_flows_out,
    double *alpha,
    double *beta,
    double *cost,
    uint64_t maxIter
) {
    using namespace lemon;

    // Give the scalar outputs a defined value up front so that every early
    // return (negative mass, solver failure) leaves them in a known state.
    *n_flows_out = 0;
    *cost = 0;

    // Count the strictly positive sources; a negative mass is an error.
    uint64_t n = 0;
    for (int i = 0; i < n1; i++) {
        double val = *(X + i);
        if (val > 0) {
            n++;
        } else if (val < 0) {
            return INFEASIBLE;
        }
    }

    // Same for the targets.
    uint64_t m = 0;
    for (int i = 0; i < n2; i++) {
        double val = *(Y + i);
        if (val > 0) {
            m++;
        } else if (val < 0) {
            return INFEASIBLE;
        }
    }

    std::vector<uint64_t> indI(n);    // indI[graph_idx] = original_source_idx
    std::vector<uint64_t> indJ(m);    // indJ[graph_idx] = original_target_idx
    std::vector<double> weights1(n);  // Source masses (positive only)
    std::vector<double> weights2(m);  // Target masses (negated: demand)

    // Reverse mapping: original_idx -> graph_idx, -1 for zero-mass entries.
    std::vector<int64_t> source_to_graph(n1, -1);
    std::vector<int64_t> target_to_graph(n2, -1);

    uint64_t cur = 0;
    for (int i = 0; i < n1; i++) {
        double val = *(X + i);
        if (val > 0) {
            weights1[cur] = val;
            indI[cur] = i;
            source_to_graph[i] = cur;
            cur++;
        }
    }

    cur = 0;
    for (int i = 0; i < n2; i++) {
        double val = *(Y + i);
        if (val > 0) {
            weights2[cur] = -val;
            indJ[cur] = i;
            target_to_graph[i] = cur;
            cur++;
        }
    }

    typedef SparseBipartiteDigraph Digraph;
    DIGRAPH_TYPEDEFS(Digraph);

    Digraph di(n, m);

    std::vector<std::pair<int, int>> edges;  // (source, target) pairs in graph numbering
    std::vector<uint64_t> edge_to_arc;       // edge_to_arc[k] = arc ID for edge k, UINT64_MAX if dropped
    std::vector<double> arc_costs;           // arc_costs[arc_id] = cost (for O(1) lookup)
    edges.reserve(n_edges);
    edge_to_arc.reserve(n_edges);
    arc_costs.reserve(n_edges);

    uint64_t valid_edge_count = 0;
    for (uint64_t k = 0; k < n_edges; k++) {
        const uint64_t src_orig = edge_sources[k];
        const uint64_t tgt_orig = edge_targets[k];

        // Look up graph indices only for in-range endpoints; an out-of-range
        // index would otherwise read past the end of the mapping vectors.
        int64_t src = -1;
        int64_t tgt = -1;
        if (src_orig < source_to_graph.size() && tgt_orig < target_to_graph.size()) {
            src = source_to_graph[src_orig];
            tgt = target_to_graph[tgt_orig];
        }

        if (src >= 0 && tgt >= 0) {
            // Target nodes are numbered after the n source nodes.
            edges.emplace_back(static_cast<int>(src),
                               static_cast<int>(tgt + static_cast<int64_t>(n)));
            edge_to_arc.push_back(valid_edge_count);
            arc_costs.push_back(edge_costs[k]);
            valid_edge_count++;
        } else {
            edge_to_arc.push_back(UINT64_MAX);
        }
    }

    // NOTE(review): the cost bookkeeping here (edge_to_arc / arc_costs) assumes
    // buildFromEdges assigns arc IDs in the order the edges were appended --
    // confirm against sparse_bipartitegraph.h.
    di.buildFromEdges(edges);

    NetworkSimplexSimple<Digraph, double, double, node_id_type> net(
        di, true, (int)(n + m), di.arcNum(), maxIter
    );

    // data() is well defined for empty vectors, unlike &v[0].
    net.supplyMap(weights1.data(), (int)n, weights2.data(), (int)m);

    for (uint64_t k = 0; k < n_edges; k++) {
        if (edge_to_arc[k] != UINT64_MAX) {
            net.setCost(edge_to_arc[k], edge_costs[k]);
        }
    }

    int ret = net.run();

    if (ret == (int)net.OPTIMAL || ret == (int)net.MAX_ITER_REACHED) {
        Arc a;
        di.first(a);
        for (; a != INVALID; di.next(a)) {
            uint64_t i = di.source(a);
            uint64_t j = di.target(a);
            double flow = net.flow(a);

            // Translate graph node numbers back to the caller's indices.
            uint64_t orig_i = indI[i];
            uint64_t orig_j = indJ[j - n];

            double arc_cost = arc_costs[a];

            *cost += flow * arc_cost;

            // Duals are written per arc endpoint, so a node with several arcs
            // is rewritten with the same value each time.
            *(alpha + orig_i) = -net.potential(i);
            *(beta + orig_j) = net.potential(j);

            // Report only flows above a tiny threshold.
            if (flow > 1e-15) {
                flow_sources_out[*n_flows_out] = orig_i;
                flow_targets_out[*n_flows_out] = orig_j;
                flow_values_out[*n_flows_out] = flow;
                (*n_flows_out)++;
            }
        }
    }
    return ret;
}
Loading
Loading