Skip to content
49 changes: 26 additions & 23 deletions metafunctions/concurrent.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
from metafunctions.core import inject_call_state
from metafunctions import exceptions

_no_value = object()

class ConcurrentMerge(FunctionMerge):
def __init__(self, function_merge: FunctionMerge):
Expand All @@ -26,36 +25,32 @@ def __init__(self, function_merge: FunctionMerge):
function_merge._merge_func,
function_merge._functions,
function_merge._function_join_str)
self._function_merge = function_merge

def __str__(self):
joined_funcs = super().__str__()
return f"concurrent{joined_funcs}"
merge_name = str(self._function_merge)
return f'concurrent{merge_name}' if merge_name.startswith('(') else f'concurrent({merge_name})'

@inject_call_state
def __call__(self, *args, **kwargs):
'''We fork here, and execute each function in a child process before joining the results
with _merge_func
'''
arg_iter, func_iter = self._get_call_iterators(args)
enumerated_funcs = enumerate(func_iter)
result_q = Queue()
error_q = Queue()

#spawn a child for each function
children = []
for i, (arg, f) in enumerate(zip(arg_iter, func_iter)):
pid = os.fork()
if not pid:
#we are the child
self._process_and_die(i, f, result_q, error_q, kwargs, arg)
children.append(pid)
for arg, (i, f) in zip(arg_iter, enumerated_funcs):
child_pid = self._process_in_fork(i, f, result_q, error_q, (arg, ), kwargs)
children.append(child_pid)

#iterate over any remaining functions for which we have no args
for j, f in enumerate(func_iter, i+1):
pid = os.fork()
if not pid:
#we are the child
self._process_and_die(j, f, result_q, error_q, kwargs)
children.append(pid)
for i, f in enumerated_funcs:
child_pid = self._process_in_fork(i, f, result_q, error_q, (), kwargs)
children.append(child_pid)

#the parent waits for all children to complete
for pid in children:
Expand All @@ -74,16 +69,24 @@ def __call__(self, *args, **kwargs):

return self._merge_func(*results)

@staticmethod
def _process_and_die(idx, func, result_q, error_q, kwargs, arg=_no_value):
'''This function is only called by child processes. Call the given function with the given
args and kwargs, put the result in result_q, then die.
def _get_call_iterators(self, args):
return self._function_merge._get_call_iterators(args)

def _call_function(self, f, args:tuple, kwargs:dict):
return self._function_merge._call_function(f, args, kwargs)

def _process_in_fork(self, idx, func, result_q, error_q, args, kwargs):
'''Call self._call_function in a child process. This function returns the ID of the child
in the parent process, while the child process calls _call_function, puts the results in
the provided queues, then exits.
'''
pid = os.fork()
if pid:
return pid

#here we are the child
try:
if arg is _no_value:
r = func(**kwargs)
else:
r = func(arg, **kwargs)
r = self._call_function(func, args, kwargs)
except Exception as e:
error_q.put(e)
else:
Expand Down
46 changes: 26 additions & 20 deletions metafunctions/core/_base.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import operator
import typing as tp
import abc
import itertools
Expand All @@ -8,7 +7,7 @@
from metafunctions.core._decorators import binary_operation
from metafunctions.core._decorators import inject_call_state
from metafunctions.core._call_state import CallState
from metafunctions.operators import concat
from metafunctions import operators
from metafunctions import exceptions


Expand Down Expand Up @@ -63,43 +62,43 @@ def __ror__(self, other):

@binary_operation
def __and__(self, other):
return FunctionMerge.combine(concat, self, other)
return FunctionMerge.combine(operators.concat, self, other)

@binary_operation
def __rand__(self, other):
return FunctionMerge.combine(concat, other, self)
return FunctionMerge.combine(operators.concat, other, self)

@binary_operation
def __add__(self, other):
return FunctionMerge(operator.add, (self, other))
return FunctionMerge(operators.add, (self, other))

@binary_operation
def __radd__(self, other):
return FunctionMerge(operator.add, (other, self))
return FunctionMerge(operators.add, (other, self))

@binary_operation
def __sub__(self, other):
return FunctionMerge(operator.sub, (self, other))
return FunctionMerge(operators.sub, (self, other))

@binary_operation
def __rsub__(self, other):
return FunctionMerge(operator.sub, (other, self))
return FunctionMerge(operators.sub, (other, self))

@binary_operation
def __mul__(self, other):
return FunctionMerge(operator.mul, (self, other))
return FunctionMerge(operators.mul, (self, other))

@binary_operation
def __rmul__(self, other):
return FunctionMerge(operator.mul, (other, self))
return FunctionMerge(operators.mul, (other, self))

@binary_operation
def __truediv__(self, other):
return FunctionMerge(operator.truediv, (self, other))
return FunctionMerge(operators.truediv, (self, other))

@binary_operation
def __rtruediv__(self, other):
return FunctionMerge(operator.truediv, (other, self))
return FunctionMerge(operators.truediv, (other, self))

@binary_operation
def __matmul__(self, other):
Expand Down Expand Up @@ -147,11 +146,11 @@ def combine(cls, *funcs):

class FunctionMerge(MetaFunction):
_character_to_operator = {
'+': operator.add,
'-': operator.sub,
'*': operator.mul,
'/': operator.truediv,
'&': concat,
'+': operators.add,
'-': operators.sub,
'*': operators.mul,
'/': operators.truediv,
'&': operators.concat,
}
_operator_to_character = {v: k for k, v in _character_to_operator.items()}

Expand Down Expand Up @@ -191,10 +190,10 @@ def __call__(self, *args, **kwargs):
# second, the first will be advanced one extra time, because zip has already called next()
# on the first iterator before discovering that the second has been exhausted.
for arg, f in zip(args_iter, func_iter):
results.append(f(arg, **kwargs))
results.append(self._call_function(f, (arg, ), kwargs))

#Any extra functions are called with no input
results.extend([f(**kwargs) for f in func_iter])
results.extend([self._call_function(f, (), kwargs) for f in func_iter])
return self._merge_func(*results)

def __repr__(self):
Expand All @@ -211,7 +210,7 @@ def combine(cls, merge_func: tp.Callable, *funcs, function_join_str=None):
'''
new_funcs = []
for f in funcs:
if isinstance(f, cls) and f._merge_func == merge_func:
if isinstance(f, cls) and f._merge_func is merge_func:
new_funcs.extend(f.functions)
else:
new_funcs.append(f)
Expand All @@ -233,6 +232,13 @@ def _get_call_iterators(self, args):

return args_iter, func_iter

def _call_function(self, f, args:tuple, kwargs:dict):
'''This function receives one function, and the args and kwargs that should be used to call
that function. It returns the result of the function call. This gets its own method so that
subclasses can customize its behaviour.
'''
return f(*args, **kwargs)


class SimpleFunction(MetaFunction):
def __init__(self, function, name=None, print_location_in_traceback=True):
Expand Down
37 changes: 37 additions & 0 deletions metafunctions/map.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
import typing as tp
import itertools

from metafunctions.concurrent import FunctionMerge
from metafunctions.operators import concat


class MergeMap(FunctionMerge):
def __init__(self, function:tp.Callable, merge_function:tp.Callable=concat):
'''
MergeMap is a FunctionMerge with only one function. When called, it behaves like the
builtin `map` function and calls its function once per item in the iterable(s) it receives.
'''
super().__init__(merge_function, (function, ))

def _get_call_iterators(self, args):
'''
Each element in args is an iterable.
'''
args_iter = zip(*args)

# Note that EVERY element in the func iter will be called, so we need to make sure the
# length of our iterator is the same as the shortest iterable we received.
shortest_arg = min(args, key=len)
func_iter = itertools.repeat(self.functions[0], len(shortest_arg))
return args_iter, func_iter

def _call_function(self, f, args:tuple, kwargs:dict):
'''In MergeMap, args will be a single element tuple containing the args for this function.
'''
return f(*args[0], **kwargs)

def __str__(self):
return f'mmap({self.functions[0]!s})'

def __repr__(self):
return f'{self.__class__.__name__}({self.functions[0]}, merge_function={self._merge_func})'
1 change: 1 addition & 0 deletions metafunctions/operators.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
'''
Extra operators used by MetaFunctions
'''
from operator import add, sub, truediv, mul

def concat(*args):
"concat(1, 2, 3) -> (1, 2, 3)"
Expand Down
41 changes: 41 additions & 0 deletions metafunctions/tests/test_concurrent.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,10 @@
from metafunctions.util import bind_call_state
from metafunctions.util import highlight_current_function
from metafunctions.util import concurrent
from metafunctions.util import mmap
from metafunctions.util import star
from metafunctions.concurrent import ConcurrentMerge
from metafunctions import operators
from metafunctions.exceptions import ConcurrentException, CompositionError, CallError


Expand Down Expand Up @@ -98,9 +101,47 @@ def test_not_concurrent(self):

def test_str_repr(self):
cab = ConcurrentMerge(a + b)
cmap = concurrent(mmap(a))

self.assertEqual(repr(cab), f'ConcurrentMerge({operator.add}, ({repr(a)}, {repr(b)}))')
self.assertEqual(str(cab), f'concurrent(a + b)')
self.assertEqual(str(cmap), f'concurrent(mmap(a))')

def test_basic_map(self):
# We can upgrade maps to run in parallel
banana = 'bnn' | concurrent(mmap(a)) | ''.join
str_concat = operators.concat | node(''.join)
batman = concurrent(mmap(a, operator=str_concat))
self.assertEqual(banana(), 'banana')
self.assertEqual(batman('nnnn'), 'nananana')

def test_multi_arg_map(self):
@node
def f(*args):
return args

m = concurrent(mmap(f))

with self.assertRaises(CompositionError):
#Because star returns a simple function, we can't upgrade it.
starmap = concurrent(star(mmap(f)))
#we have to wrap concurrent in star instead.
starmap = star(concurrent(mmap(f)))

mapstar = concurrent(mmap(star(f)))

self.assertEqual(m([1, 2, 3], [4, 5, 6]), ((1, 4), (2, 5), (3, 6)))
self.assertEqual(m([1, 2, 3]), ((1, ), (2, ), (3, )))

with self.assertRaises(TypeError):
self.assertEqual(starmap([1, 2, 3]))
self.assertEqual(starmap([[1, 2, 3]]), m([1, 2, 3]))

cmp = ([1, 2, 3], [4, 5, 6]) | starmap
self.assertEqual(cmp(), ((1, 4), (2, 5), (3, 6)))

cmp = ([1, 2, 3], [4, 5, 6]) | mapstar
self.assertEqual(cmp(), ((1, 2, 3), (4, 5, 6)))

### Simple Sample Functions ###
@node
Expand Down
58 changes: 58 additions & 0 deletions metafunctions/tests/test_map.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@

from metafunctions.tests.util import BaseTestCase
from metafunctions.util import node
from metafunctions.util import star
from metafunctions.util import mmap
from metafunctions.map import MergeMap
from metafunctions import operators

class TestIntegration(BaseTestCase):
def test_basic(self):
banana = 'bnn' | MergeMap(a) | ''.join
str_concat = operators.concat | node(''.join)
batman = MergeMap(a, merge_function=str_concat)
self.assertEqual(banana(), 'banana')
self.assertEqual(batman('nnnn'), 'nananana')

def test_multi_arg(self):
@node
def f(*args):
return args

m = mmap(f)
starmap = star(mmap(f))
mapstar = mmap(star(f))

self.assertEqual(m([1, 2, 3], [4, 5, 6]), ((1, 4), (2, 5), (3, 6)))
self.assertEqual(m([1, 2, 3]), ((1, ), (2, ), (3, )))

with self.assertRaises(TypeError):
self.assertEqual(starmap([1, 2, 3]))
self.assertEqual(starmap([[1, 2, 3]]), m([1, 2, 3]))

cmp = ([1, 2, 3], [4, 5, 6]) | starmap
self.assertEqual(cmp(), ((1, 4), (2, 5), (3, 6)))

cmp = ([1, 2, 3], [4, 5, 6]) | mapstar
self.assertEqual(cmp(), ((1, 2, 3), (4, 5, 6)))

def test_auto_meta(self):
mapsum = mmap(sum)
self.assertEqual(mapsum([[1, 2], [3, 4]]), (3, 7))
self.assertEqual(str(mapsum), f'mmap(sum)')

def test_str_repr(self):
m = MergeMap(a)
self.assertEqual(str(m), 'mmap(a)')
self.assertEqual(repr(m), f'MergeMap(a, merge_function={operators.concat})')


@node
def a(x):
return x + 'a'
@node
def b(x):
return x + 'b'
@node
def c(x):
return x + 'c'
Loading