Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
41 commits
Select commit Hold shift + click to select a range
a3d7107
Added : OpenMP support
Jun 11, 2021
692960e
Merge branch 'master' into OTOMP
kguerda-idris Jun 11, 2021
403f74b
Merge branch 'PythonOT:master' into OTOMP
ncassereau Jun 15, 2021
b5f11e2
Commit clean up
Jun 15, 2021
d4c1b3d
Number of CPUs correctly calculated on SLURM clusters
Jun 15, 2021
d6e5953
Corrected number of processes for cluster slurm
Jun 15, 2021
c84285a
Mistake corrected
Jun 15, 2021
b2c958d
parmap is now deprecated
Jul 2, 2021
1679a3e
Now a different solver is used depending on the requested number of t…
Jul 2, 2021
ad5c0c1
Tiny mistake corrected
Jul 2, 2021
88f6aa9
Folders are now in the ot library instead of at the root
Jul 2, 2021
99e51ed
Merge branch 'PythonOT:master' into OTOMP
ncassereau Jul 2, 2021
b6b26cc
Helpers is now correctly placed
Jul 2, 2021
c6892f4
Merge branch 'OTOMP' of https://github.com/kguerda-idris/POT into OTOMP
Jul 2, 2021
829f181
Attempt to make compilation work smoothly
Jul 5, 2021
b052fc2
OS compatible path
Jul 5, 2021
3f68dc2
NumThreads now defaults to 1
Jul 5, 2021
42b2f42
Better flags
Jul 5, 2021
937fec6
Mistake corrected in case of OpenMP unavailability
Jul 5, 2021
fd25881
Revert OpenMP flags modification, which do not compile on Windows
Jul 5, 2021
dd0c8a5
Merge branch 'master' into OTOMP
rflamary Sep 6, 2021
dc9a475
Test helper functions
Sep 10, 2021
78c0934
Helpers comments
Sep 10, 2021
47aedd2
Documentation update
Sep 10, 2021
8a2bbf1
File title corrected
Sep 10, 2021
7d45569
Merge branch 'master' into OTOMP
rflamary Sep 13, 2021
0258f85
Warning no longer using print
Sep 15, 2021
7aee5fc
Merge branch 'OTOMP' of https://github.com/kguerda-idris/POT into OTOMP
Sep 15, 2021
2dde576
Last attempt for macos compilation
Sep 15, 2021
7c565ad
pls work
Sep 15, 2021
f86d406
atempt
Sep 16, 2021
9df20c5
solving a type error
Sep 16, 2021
b744a01
TypeError OpenMP
Sep 16, 2021
e3e4c37
Compilation finally working on Windows
Sep 16, 2021
78e4ca4
Merge branch 'master' into OTOMP
rflamary Sep 20, 2021
71cda54
Bug solve, number of threads now correctly selected
Sep 21, 2021
8ca091f
Merge branch 'OTOMP' of https://github.com/kguerda-idris/POT into OTOMP
Sep 21, 2021
d2359c5
64 bits solver to avoid overflows for bigger problems
Sep 24, 2021
bfbb582
64 bits EMD corrected
Sep 24, 2021
0491bfe
build wheels
Sep 24, 2021
8e86752
Merge branch 'master' into OTOMP
rflamary Sep 28, 2021
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
85 changes: 85 additions & 0 deletions ot/helpers/openmp_helpers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
"""Helpers for OpenMP support during the build."""

# This code is adapted for a large part from the astropy openmp helpers, which
# can be found at: https://github.com/astropy/extension-helpers/blob/master/extension_helpers/_openmp_helpers.py # noqa


import os
import sys
import textwrap
import subprocess

from distutils.errors import CompileError, LinkError

from pre_build_helpers import compile_test_program


def get_openmp_flag(compiler):
    """Return the compiler/linker flags that enable OpenMP.

    The flags are chosen from ``sys.platform`` and from the name of the
    compiler, following the logic of the astropy / scikit-learn build helpers.

    Parameters
    ----------
    compiler : object
        A compiler object. When it exposes a non-empty ``compiler`` attribute,
        the first element of that sequence is used as the compiler name;
        otherwise the name of the object's class is used instead.

    Returns
    -------
    omp_flag : list of str
        Flags to pass to the compiler. The list is empty on macOS when
        ``CPPFLAGS`` already mentions ``openmp``.
    """
    # Fall back to the class name when there is no usable command list
    # (missing attribute, ``None`` or an empty sequence).
    command = getattr(compiler, 'compiler', None)
    if command:
        compiler_name = command[0]
    else:
        compiler_name = compiler.__class__.__name__

    if sys.platform == "win32" and ('icc' in compiler_name
                                    or 'icl' in compiler_name):
        # Intel compiler on Windows
        omp_flag = ['/Qopenmp']
    elif sys.platform == "win32":
        # Any other Windows compiler gets the MSVC-style flag
        omp_flag = ['/openmp']
    elif sys.platform in ("darwin", "linux") and "icc" in compiler_name:
        # Intel compiler on macOS / Linux
        omp_flag = ['-qopenmp']
    elif sys.platform == "darwin" and 'openmp' in os.getenv('CPPFLAGS', ''):
        # OpenMP is already requested through CPPFLAGS: add nothing
        omp_flag = []
    else:
        # Default flag for GCC and clang:
        omp_flag = ['-fopenmp']
        if sys.platform.startswith("darwin"):
            omp_flag += ["-Xpreprocessor", "-lomp"]
    return omp_flag


def check_openmp_support():
    """Check whether an OpenMP test program can be compiled and run.

    A tiny C program printing ``nthreads=<N>`` from every thread of a
    parallel region is compiled with the flags from ``get_openmp_flag`` and,
    unless cross-compiling, executed.

    Returns
    -------
    openmp_supported : bool
        True when the program printed one ``nthreads=`` line per thread, or
        when ``PYTHON_CROSSENV`` is set and the program merely compiled.
    compile_flags : list
        The OpenMP flags used for the test build; an empty list when
        compiling, linking or running the program failed.
    """
    code = textwrap.dedent(
        """\
        #include <omp.h>
        #include <stdio.h>
        int main(void) {
        #pragma omp parallel
        printf("nthreads=%d\\n", omp_get_num_threads());
        return 0;
        }
        """)

    # Only forward library-related linker flags from LDFLAGS.
    extra_preargs = os.getenv('LDFLAGS', None)
    if extra_preargs is not None:
        extra_preargs = [
            flag for flag in extra_preargs.split()
            if flag.startswith(('-L', '-Wl,-rpath', '-l'))]

    # Passed as a callable: it is resolved against the actual compiler
    # inside compile_test_program.
    extra_postargs = get_openmp_flag

    try:
        output, compile_flags = compile_test_program(
            code,
            extra_preargs=extra_preargs,
            extra_postargs=extra_postargs
        )

        if output and 'nthreads=' in output[0]:
            try:
                nthreads = int(output[0].strip().split('=')[1])
            except ValueError:
                # Unparseable thread count: treat as "not supported" rather
                # than aborting the whole build probe.
                openmp_supported = False
            else:
                # Each thread prints exactly one line.
                openmp_supported = len(output) == nthreads
        elif "PYTHON_CROSSENV" in os.environ:
            # Since we can't run the test program when cross-compiling
            # assume that openmp is supported if the program can be
            # compiled.
            openmp_supported = True
        else:
            openmp_supported = False

    except (CompileError, LinkError, subprocess.CalledProcessError):
        openmp_supported = False
        compile_flags = []
    return openmp_supported, compile_flags
87 changes: 87 additions & 0 deletions ot/helpers/pre_build_helpers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
"""Helpers to check build environment before actual build of POT"""

import os
import sys
import glob
import tempfile
import setuptools # noqa
import subprocess

from distutils.dist import Distribution
from distutils.sysconfig import customize_compiler
from numpy.distutils.ccompiler import new_compiler
from numpy.distutils.command.config_compiler import config_cc


def _get_compiler():
    """Build a compiler object matching the one used to build POT.

    Intended to honour a compiler specified as follows:
        - python setup.py build_ext --compiler=<compiler>
        - CC=<compiler> python setup.py build_ext

    Returns
    -------
    The object created by ``new_compiler`` after ``customize_compiler`` has
    been applied to it.
    """
    script_name = os.path.basename(sys.argv[0])
    distribution = Distribution({
        'script_name': script_name,
        'script_args': sys.argv[1:],
        'cmdclass': {'config_cc': config_cc},
    })

    # NOTE(review): no parse_config_files() / parse_command_line() call is
    # visible here; confirm that command_options is populated at this point,
    # otherwise the ``--compiler`` branch below never triggers.
    build_ext_opts = distribution.command_options.get('build_ext')
    requested = None
    if build_ext_opts is not None and 'compiler' in build_ext_opts:
        requested = build_ext_opts['compiler'][1]

    compiler_obj = new_compiler(compiler=requested)
    customize_compiler(compiler_obj)
    return compiler_obj


def compile_test_program(code, extra_preargs=None, extra_postargs=None):
    """Check that some C code can be compiled and run.

    The program is written, compiled and linked inside a temporary directory;
    the working directory is always restored afterwards.

    Parameters
    ----------
    code : str
        Source of the C test program.
    extra_preargs : list or callable, optional
        Arguments forwarded to the link step. A callable is invoked with the
        compiler object and must return the list. Defaults to no arguments.
    extra_postargs : list or callable, optional
        Arguments forwarded to both the compile and the link step. A callable
        is invoked with the compiler object. Defaults to no arguments.

    Returns
    -------
    output : list of str
        Lines printed by the program, or an empty list when
        ``PYTHON_CROSSENV`` is set (the program is then not executed).
    extra_postargs : list
        The resolved post-arguments actually used.

    Raises
    ------
    subprocess.CalledProcessError
        When the test program exits with a non-zero status. Errors raised by
        the compile or link steps propagate unchanged.
    """
    ccompiler = _get_compiler()

    # Fresh lists per call: the post-arguments are returned to the caller, so
    # a shared mutable default could leak modifications between calls.
    if extra_preargs is None:
        extra_preargs = []
    if extra_postargs is None:
        extra_postargs = []

    # extra_(pre/post)args can be a callable to make it possible to get its
    # value from the compiler
    if callable(extra_preargs):
        extra_preargs = extra_preargs(ccompiler)
    if callable(extra_postargs):
        extra_postargs = extra_postargs(ccompiler)

    start_dir = os.path.abspath('.')

    with tempfile.TemporaryDirectory() as tmp_dir:
        try:
            os.chdir(tmp_dir)

            # Write test program
            with open('test_program.c', 'w') as f:
                f.write(code)

            os.mkdir('objects')

            # Compile test program
            ccompiler.compile(['test_program.c'], output_dir='objects',
                              extra_postargs=extra_postargs)

            # Link test program
            objects = glob.glob(
                os.path.join('objects', '*' + ccompiler.obj_extension))
            ccompiler.link_executable(objects, 'test_program',
                                      extra_preargs=extra_preargs,
                                      extra_postargs=extra_postargs)

            if "PYTHON_CROSSENV" not in os.environ:
                # Run test program if not cross compiling
                # will raise a CalledProcessError if return code was non-zero
                output = subprocess.check_output('./test_program')
                output = output.decode(
                    sys.stdout.encoding or 'utf-8').splitlines()
            else:
                # Return an empty output if we are cross compiling
                # as we cannot run the test_program
                output = []
        finally:
            # Leave the temporary directory before it gets removed.
            os.chdir(start_dir)

    return output, extra_postargs
5 changes: 2 additions & 3 deletions ot/lp/EMD.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,19 +18,18 @@

#include <iostream>
#include <vector>
#include "network_simplex_simple.h"

using namespace lemon;
typedef unsigned int node_id_type;

enum ProblemType {
INFEASIBLE,
OPTIMAL,
UNBOUNDED,
MAX_ITER_REACHED
MAX_ITER_REACHED
};

int EMD_wrap(int n1,int n2, double *X, double *Y,double *D, double *G, double* alpha, double* beta, double *cost, int maxIter);
int EMD_wrap_omp(int n1,int n2, double *X, double *Y,double *D, double *G, double* alpha, double* beta, double *cost, int maxIter, int numThreads);



Expand Down
124 changes: 117 additions & 7 deletions ot/lp/EMD_wrapper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,16 +12,22 @@
*
*/


#include "network_simplex_simple.h"
#include "network_simplex_simple_omp.h"
#include "EMD.h"
#include <cstdint>


int EMD_wrap(int n1, int n2, double *X, double *Y, double *D, double *G,
double* alpha, double* beta, double *cost, int maxIter) {
// beware M and C anre strored in row major C style!!!
int n, m, i, cur;
// beware M and C are stored in row major C style!!!

using namespace lemon;
int n, m, cur;

typedef FullBipartiteDigraph Digraph;
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

remove the typedef if you don't use it anymore?

DIGRAPH_TYPEDEFS(FullBipartiteDigraph);
DIGRAPH_TYPEDEFS(Digraph);

// Get the number of non zero coordinates for r and c
n=0;
Expand All @@ -48,7 +54,7 @@ int EMD_wrap(int n1, int n2, double *X, double *Y, double *D, double *G,
std::vector<int> indI(n), indJ(m);
std::vector<double> weights1(n), weights2(m);
Digraph di(n, m);
NetworkSimplexSimple<Digraph,double,double, node_id_type> net(di, true, n+m, n*m, maxIter);
NetworkSimplexSimple<Digraph,double,double, node_id_type> net(di, true, n+m, ((int64_t)n)*((int64_t)m), maxIter);

// Set supply and demand, don't account for 0 values (faster)

Expand Down Expand Up @@ -76,23 +82,26 @@ int EMD_wrap(int n1, int n2, double *X, double *Y, double *D, double *G,
net.supplyMap(&weights1[0], n, &weights2[0], m);

// Set the cost of each edge
int64_t idarc = 0;
for (int i=0; i<n; i++) {
for (int j=0; j<m; j++) {
double val=*(D+indI[i]*n2+indJ[j]);
net.setCost(di.arcFromId(i*m+j), val);
net.setCost(di.arcFromId(idarc), val);
++idarc;
}
}


// Solve the problem with the network simplex algorithm

int ret=net.run();
int i, j;
if (ret==(int)net.OPTIMAL || ret==(int)net.MAX_ITER_REACHED) {
*cost = 0;
Arc a; di.first(a);
for (; a != INVALID; di.next(a)) {
int i = di.source(a);
int j = di.target(a);
i = di.source(a);
j = di.target(a);
double flow = net.flow(a);
*cost += flow * (*(D+indI[i]*n2+indJ[j-n]));
*(G+indI[i]*n2+indJ[j-n]) = flow;
Expand All @@ -106,3 +115,104 @@ int EMD_wrap(int n1, int n2, double *X, double *Y, double *D, double *G,
return ret;
}







int EMD_wrap_omp(int n1, int n2, double *X, double *Y, double *D, double *G,
double* alpha, double* beta, double *cost, int maxIter, int numThreads) {
// beware M and C are stored in row major C style!!!

using namespace lemon_omp;
int n, m, cur;

typedef FullBipartiteDigraph Digraph;
DIGRAPH_TYPEDEFS(Digraph);

// Get the number of non zero coordinates for r and c
n=0;
for (int i=0; i<n1; i++) {
double val=*(X+i);
if (val>0) {
n++;
}else if(val<0){
return INFEASIBLE;
}
}
m=0;
for (int i=0; i<n2; i++) {
double val=*(Y+i);
if (val>0) {
m++;
}else if(val<0){
return INFEASIBLE;
}
}

// Define the graph

std::vector<int> indI(n), indJ(m);
std::vector<double> weights1(n), weights2(m);
Digraph di(n, m);
NetworkSimplexSimple<Digraph,double,double, node_id_type> net(di, true, n+m, ((int64_t)n)*((int64_t)m), maxIter, numThreads);

// Set supply and demand, don't account for 0 values (faster)

cur=0;
for (int i=0; i<n1; i++) {
double val=*(X+i);
if (val>0) {
weights1[ cur ] = val;
indI[cur++]=i;
}
}

// Demand is actually negative supply...

cur=0;
for (int i=0; i<n2; i++) {
double val=*(Y+i);
if (val>0) {
weights2[ cur ] = -val;
indJ[cur++]=i;
}
}


net.supplyMap(&weights1[0], n, &weights2[0], m);

// Set the cost of each edge
int64_t idarc = 0;
for (int i=0; i<n; i++) {
for (int j=0; j<m; j++) {
double val=*(D+indI[i]*n2+indJ[j]);
net.setCost(di.arcFromId(idarc), val);
++idarc;
}
}


// Solve the problem with the network simplex algorithm

int ret=net.run();
int i, j;
if (ret==(int)net.OPTIMAL || ret==(int)net.MAX_ITER_REACHED) {
*cost = 0;
Arc a; di.first(a);
for (; a != INVALID; di.next(a)) {
i = di.source(a);
j = di.target(a);
double flow = net.flow(a);
*cost += flow * (*(D+indI[i]*n2+indJ[j-n]));
*(G+indI[i]*n2+indJ[j-n]) = flow;
*(alpha + indI[i]) = -net.potential(i);
*(beta + indJ[j-n]) = net.potential(j);
}

}


return ret;
}
Loading