diff --git a/python/pyspark/rddsampler.py b/python/pyspark/rddsampler.py index 55e247da0e4dc..cf7a80acea88d 100644 --- a/python/pyspark/rddsampler.py +++ b/python/pyspark/rddsampler.py @@ -22,24 +22,32 @@ class RDDSamplerBase(object): def __init__(self, withReplacement, seed=None): - try: - import numpy - self._use_numpy = True - except ImportError: - print >> sys.stderr, ( - "NumPy does not appear to be installed. " - "Falling back to default random generator for sampling.") - self._use_numpy = False - self._seed = seed if seed is not None else random.randint(0, sys.maxint) self._withReplacement = withReplacement self._random = None self._split = None self._rand_initialized = False + self._tried_numpy = False + try: + import numpy + self._driver_has_numpy = True + except ImportError: + self._driver_has_numpy = False def initRandomGenerator(self, split): + if not self._tried_numpy: + self._use_numpy = False + if self._driver_has_numpy: + try: + import numpy + self._use_numpy = True + except ImportError: + print >> sys.stderr, ( + "NumPy does not appear to be installed. " + "Falling back to default random generator for sampling.") + self._tried_numpy = True + if self._use_numpy: - import numpy self._random = numpy.random.RandomState(self._seed) else: self._random = random.Random(self._seed)