-
Notifications
You must be signed in to change notification settings - Fork 529
[MRG] OpenMP support #260
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[MRG] OpenMP support #260
Changes from all commits
a3d7107
692960e
403f74b
b5f11e2
d4c1b3d
d6e5953
c84285a
b2c958d
1679a3e
ad5c0c1
88f6aa9
99e51ed
b6b26cc
c6892f4
829f181
b052fc2
3f68dc2
42b2f42
937fec6
fd25881
dd0c8a5
dc9a475
78c0934
47aedd2
8a2bbf1
7d45569
0258f85
7aee5fc
2dde576
7c565ad
f86d406
9df20c5
b744a01
e3e4c37
78e4ca4
71cda54
8ca091f
d2359c5
bfbb582
0491bfe
8e86752
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,85 @@ | ||
"""Helpers for OpenMP support during the build.""" | ||
|
||
# This code is adapted for a large part from the astropy openmp helpers, which | ||
# can be found at: https://github.com/astropy/extension-helpers/blob/master/extension_helpers/_openmp_helpers.py # noqa | ||
|
||
|
||
import os | ||
import sys | ||
import textwrap | ||
import subprocess | ||
|
||
from distutils.errors import CompileError, LinkError | ||
|
||
from pre_build_helpers import compile_test_program | ||
|
||
|
||
def get_openmp_flag(compiler): | ||
"""Get openmp flags for a given compiler""" | ||
|
||
if hasattr(compiler, 'compiler'): | ||
compiler = compiler.compiler[0] | ||
else: | ||
compiler = compiler.__class__.__name__ | ||
|
||
if sys.platform == "win32" and ('icc' in compiler or 'icl' in compiler): | ||
omp_flag = ['/Qopenmp'] | ||
elif sys.platform == "win32": | ||
omp_flag = ['/openmp'] | ||
elif sys.platform in ("darwin", "linux") and "icc" in compiler: | ||
omp_flag = ['-qopenmp'] | ||
elif sys.platform == "darwin" and 'openmp' in os.getenv('CPPFLAGS', ''): | ||
omp_flag = [] | ||
else: | ||
# Default flag for GCC and clang: | ||
omp_flag = ['-fopenmp'] | ||
if sys.platform.startswith("darwin"): | ||
omp_flag += ["-Xpreprocessor", "-lomp"] | ||
return omp_flag | ||
|
||
|
||
def check_openmp_support(): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. test this function also |
||
"""Check whether OpenMP test code can be compiled and run""" | ||
|
||
code = textwrap.dedent( | ||
"""\ | ||
#include <omp.h> | ||
#include <stdio.h> | ||
int main(void) { | ||
#pragma omp parallel | ||
printf("nthreads=%d\\n", omp_get_num_threads()); | ||
return 0; | ||
} | ||
""") | ||
|
||
extra_preargs = os.getenv('LDFLAGS', None) | ||
if extra_preargs is not None: | ||
extra_preargs = extra_preargs.strip().split(" ") | ||
extra_preargs = [ | ||
flag for flag in extra_preargs | ||
if flag.startswith(('-L', '-Wl,-rpath', '-l'))] | ||
|
||
extra_postargs = get_openmp_flag | ||
|
||
try: | ||
output, compile_flags = compile_test_program( | ||
code, | ||
extra_preargs=extra_preargs, | ||
extra_postargs=extra_postargs | ||
) | ||
|
||
if output and 'nthreads=' in output[0]: | ||
nthreads = int(output[0].strip().split('=')[1]) | ||
openmp_supported = len(output) == nthreads | ||
elif "PYTHON_CROSSENV" in os.environ: | ||
# Since we can't run the test program when cross-compiling | ||
# assume that openmp is supported if the program can be | ||
# compiled. | ||
openmp_supported = True | ||
else: | ||
openmp_supported = False | ||
|
||
except (CompileError, LinkError, subprocess.CalledProcessError): | ||
openmp_supported = False | ||
compile_flags = [] | ||
return openmp_supported, compile_flags |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,87 @@ | ||
"""Helpers to check build environment before actual build of POT""" | ||
|
||
import os | ||
import sys | ||
import glob | ||
import tempfile | ||
import setuptools # noqa | ||
import subprocess | ||
|
||
from distutils.dist import Distribution | ||
from distutils.sysconfig import customize_compiler | ||
from numpy.distutils.ccompiler import new_compiler | ||
from numpy.distutils.command.config_compiler import config_cc | ||
|
||
|
||
def _get_compiler(): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. test this function or at least the function that is calling it |
||
"""Get a compiler equivalent to the one that will be used to build POT | ||
Handles compiler specified as follows: | ||
- python setup.py build_ext --compiler=<compiler> | ||
- CC=<compiler> python setup.py build_ext | ||
""" | ||
dist = Distribution({'script_name': os.path.basename(sys.argv[0]), | ||
'script_args': sys.argv[1:], | ||
'cmdclass': {'config_cc': config_cc}}) | ||
|
||
cmd_opts = dist.command_options.get('build_ext') | ||
if cmd_opts is not None and 'compiler' in cmd_opts: | ||
compiler = cmd_opts['compiler'][1] | ||
else: | ||
compiler = None | ||
|
||
ccompiler = new_compiler(compiler=compiler) | ||
customize_compiler(ccompiler) | ||
|
||
return ccompiler | ||
|
||
|
||
def compile_test_program(code, extra_preargs=[], extra_postargs=[]): | ||
"""Check that some C code can be compiled and run""" | ||
ccompiler = _get_compiler() | ||
|
||
# extra_(pre/post)args can be a callable to make it possible to get its | ||
# value from the compiler | ||
if callable(extra_preargs): | ||
extra_preargs = extra_preargs(ccompiler) | ||
if callable(extra_postargs): | ||
extra_postargs = extra_postargs(ccompiler) | ||
|
||
start_dir = os.path.abspath('.') | ||
|
||
with tempfile.TemporaryDirectory() as tmp_dir: | ||
try: | ||
os.chdir(tmp_dir) | ||
|
||
# Write test program | ||
with open('test_program.c', 'w') as f: | ||
f.write(code) | ||
|
||
os.mkdir('objects') | ||
|
||
# Compile, test program | ||
ccompiler.compile(['test_program.c'], output_dir='objects', | ||
extra_postargs=extra_postargs) | ||
|
||
# Link test program | ||
objects = glob.glob( | ||
os.path.join('objects', '*' + ccompiler.obj_extension)) | ||
ccompiler.link_executable(objects, 'test_program', | ||
extra_preargs=extra_preargs, | ||
extra_postargs=extra_postargs) | ||
|
||
if "PYTHON_CROSSENV" not in os.environ: | ||
# Run test program if not cross compiling | ||
# will raise a CalledProcessError if return code was non-zero | ||
output = subprocess.check_output('./test_program') | ||
output = output.decode( | ||
sys.stdout.encoding or 'utf-8').splitlines() | ||
else: | ||
# Return an empty output if we are cross compiling | ||
# as we cannot run the test_program | ||
output = [] | ||
except Exception: | ||
raise | ||
finally: | ||
os.chdir(start_dir) | ||
|
||
return output, extra_postargs |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -12,16 +12,22 @@ | |
* | ||
*/ | ||
|
||
|
||
#include "network_simplex_simple.h" | ||
#include "network_simplex_simple_omp.h" | ||
#include "EMD.h" | ||
#include <cstdint> | ||
|
||
|
||
int EMD_wrap(int n1, int n2, double *X, double *Y, double *D, double *G, | ||
double* alpha, double* beta, double *cost, int maxIter) { | ||
// beware M and C anre strored in row major C style!!! | ||
int n, m, i, cur; | ||
// beware M and C are stored in row major C style!!! | ||
|
||
using namespace lemon; | ||
int n, m, cur; | ||
|
||
typedef FullBipartiteDigraph Digraph; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. remove the typedef if you don't use it anymore? |
||
DIGRAPH_TYPEDEFS(FullBipartiteDigraph); | ||
DIGRAPH_TYPEDEFS(Digraph); | ||
|
||
// Get the number of non zero coordinates for r and c | ||
n=0; | ||
|
@@ -48,7 +54,7 @@ int EMD_wrap(int n1, int n2, double *X, double *Y, double *D, double *G, | |
std::vector<int> indI(n), indJ(m); | ||
std::vector<double> weights1(n), weights2(m); | ||
Digraph di(n, m); | ||
NetworkSimplexSimple<Digraph,double,double, node_id_type> net(di, true, n+m, n*m, maxIter); | ||
NetworkSimplexSimple<Digraph,double,double, node_id_type> net(di, true, n+m, ((int64_t)n)*((int64_t)m), maxIter); | ||
|
||
// Set supply and demand, don't account for 0 values (faster) | ||
|
||
|
@@ -76,23 +82,26 @@ int EMD_wrap(int n1, int n2, double *X, double *Y, double *D, double *G, | |
net.supplyMap(&weights1[0], n, &weights2[0], m); | ||
|
||
// Set the cost of each edge | ||
int64_t idarc = 0; | ||
for (int i=0; i<n; i++) { | ||
for (int j=0; j<m; j++) { | ||
double val=*(D+indI[i]*n2+indJ[j]); | ||
net.setCost(di.arcFromId(i*m+j), val); | ||
net.setCost(di.arcFromId(idarc), val); | ||
++idarc; | ||
} | ||
} | ||
|
||
|
||
// Solve the problem with the network simplex algorithm | ||
|
||
int ret=net.run(); | ||
int i, j; | ||
if (ret==(int)net.OPTIMAL || ret==(int)net.MAX_ITER_REACHED) { | ||
*cost = 0; | ||
Arc a; di.first(a); | ||
for (; a != INVALID; di.next(a)) { | ||
int i = di.source(a); | ||
int j = di.target(a); | ||
i = di.source(a); | ||
j = di.target(a); | ||
double flow = net.flow(a); | ||
*cost += flow * (*(D+indI[i]*n2+indJ[j-n])); | ||
*(G+indI[i]*n2+indJ[j-n]) = flow; | ||
|
@@ -106,3 +115,104 @@ int EMD_wrap(int n1, int n2, double *X, double *Y, double *D, double *G, | |
return ret; | ||
} | ||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
int EMD_wrap_omp(int n1, int n2, double *X, double *Y, double *D, double *G, | ||
double* alpha, double* beta, double *cost, int maxIter, int numThreads) { | ||
// beware M and C are stored in row major C style!!! | ||
|
||
using namespace lemon_omp; | ||
int n, m, cur; | ||
|
||
typedef FullBipartiteDigraph Digraph; | ||
DIGRAPH_TYPEDEFS(Digraph); | ||
|
||
// Get the number of non zero coordinates for r and c | ||
n=0; | ||
for (int i=0; i<n1; i++) { | ||
double val=*(X+i); | ||
if (val>0) { | ||
n++; | ||
}else if(val<0){ | ||
return INFEASIBLE; | ||
} | ||
} | ||
m=0; | ||
for (int i=0; i<n2; i++) { | ||
double val=*(Y+i); | ||
if (val>0) { | ||
m++; | ||
}else if(val<0){ | ||
return INFEASIBLE; | ||
} | ||
} | ||
|
||
// Define the graph | ||
|
||
std::vector<int> indI(n), indJ(m); | ||
std::vector<double> weights1(n), weights2(m); | ||
Digraph di(n, m); | ||
NetworkSimplexSimple<Digraph,double,double, node_id_type> net(di, true, n+m, ((int64_t)n)*((int64_t)m), maxIter, numThreads); | ||
|
||
// Set supply and demand, don't account for 0 values (faster) | ||
|
||
cur=0; | ||
for (int i=0; i<n1; i++) { | ||
double val=*(X+i); | ||
if (val>0) { | ||
weights1[ cur ] = val; | ||
indI[cur++]=i; | ||
} | ||
} | ||
|
||
// Demand is actually negative supply... | ||
|
||
cur=0; | ||
for (int i=0; i<n2; i++) { | ||
double val=*(Y+i); | ||
if (val>0) { | ||
weights2[ cur ] = -val; | ||
indJ[cur++]=i; | ||
} | ||
} | ||
|
||
|
||
net.supplyMap(&weights1[0], n, &weights2[0], m); | ||
|
||
// Set the cost of each edge | ||
int64_t idarc = 0; | ||
for (int i=0; i<n; i++) { | ||
for (int j=0; j<m; j++) { | ||
double val=*(D+indI[i]*n2+indJ[j]); | ||
net.setCost(di.arcFromId(idarc), val); | ||
++idarc; | ||
} | ||
} | ||
|
||
|
||
// Solve the problem with the network simplex algorithm | ||
|
||
int ret=net.run(); | ||
int i, j; | ||
if (ret==(int)net.OPTIMAL || ret==(int)net.MAX_ITER_REACHED) { | ||
*cost = 0; | ||
Arc a; di.first(a); | ||
for (; a != INVALID; di.next(a)) { | ||
i = di.source(a); | ||
j = di.target(a); | ||
double flow = net.flow(a); | ||
*cost += flow * (*(D+indI[i]*n2+indJ[j-n])); | ||
*(G+indI[i]*n2+indJ[j-n]) = flow; | ||
*(alpha + indI[i]) = -net.potential(i); | ||
*(beta + indJ[j-n]) = net.potential(j); | ||
} | ||
|
||
} | ||
|
||
|
||
return ret; | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Add short documentation to this function and describe what it does.
The function should be tested