From f788cb9c04f0c2a96b5002f0ecef1d1504132003 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marc-Alexandre=20C=C3=B4t=C3=A9?= Date: Mon, 7 Nov 2016 11:56:30 -0500 Subject: [PATCH 1/5] Add a function to concatenate multiple ArraySequences object given an axis. --- nibabel/streamlines/array_sequence.py | 29 +++++++++++++++++++ .../streamlines/tests/test_array_sequence.py | 17 ++++++++++- 2 files changed, 45 insertions(+), 1 deletion(-) diff --git a/nibabel/streamlines/array_sequence.py b/nibabel/streamlines/array_sequence.py index 94c8aaf004..0312dabb69 100644 --- a/nibabel/streamlines/array_sequence.py +++ b/nibabel/streamlines/array_sequence.py @@ -318,6 +318,7 @@ def __getitem__(self, idx): raise TypeError("Index must be either an int, a slice, a list of int" " or a ndarray of bool! Not " + str(type(idx))) + def __iter__(self): if len(self._lengths) != len(self._offsets): raise ValueError("ArraySequence object corrupted:" @@ -380,3 +381,31 @@ def create_arraysequences_from_generator(gen, n): for seq in seqs: seq.finalize_append() return seqs + + +def concatenate(seqs, axis): + """ Concatenates multiple :class:`ArraySequence` objects along an axis. + + Parameters + ---------- + seqs: list of :class:`ArraySequence` objects + Sequences to concatenate. + axis : int + Axis along which the sequences will be concatenated. + + Returns + ------- + new_seq: :class:`ArraySequence` object + New :class:`ArraySequence` object which is the result of + concatenating multiple sequences along the given axis. + """ + new_seq = seqs[0].copy() + if axis == 0: + # This is the same as an extend. + for seq in seqs[1:]: + new_seq.extend(seq) + + return new_seq + + new_seq._data = np.concatenate([seq._data for seq in seqs], axis=axis) + return new_seq diff --git a/nibabel/streamlines/tests/test_array_sequence.py b/nibabel/streamlines/tests/test_array_sequence.py index a2ebd3a22e..903fefe1a9 100644 --- a/nibabel/streamlines/tests/test_array_sequence.py +++ b/nibabel/streamlines/tests/test_array_sequence.py @@ -8,7 +8,7 @@ from nibabel.testing import assert_arrays_equal from numpy.testing import assert_array_equal -from ..array_sequence import ArraySequence, is_array_sequence +from ..array_sequence import ArraySequence, is_array_sequence, concatenate, create_arraysequences_from_generator SEQ_DATA = {} @@ -299,3 +299,18 @@ def test_save_and_load_arraysequence(self): # Make sure we can add new elements to it. loaded_seq.append(SEQ_DATA['data'][0]) + + +def test_concatenate(): + seq = SEQ_DATA['seq'].copy() # In case there is in-place modification. + seqs = [seq[:, [i]] for i in range(seq.common_shape[0])] + new_seq = concatenate(seqs, axis=1) + check_arr_seq(new_seq, SEQ_DATA['data']) + assert_true(not new_seq._is_view) + + seq = SEQ_DATA['seq'].copy() # In case there is in-place modification. + seqs = [seq[:, [i]] for i in range(seq.common_shape[0])] + new_seq = concatenate(seqs, axis=0) + + assert_true(len(new_seq), seq.common_shape[0]*len(seq)) + assert_array_equal(new_seq._data, seq._data.T.reshape((-1, 1))) From 261bf3beb7a04d58dd0380474a60188094941522 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marc-Alexandre=20C=C3=B4t=C3=A9?= Date: Mon, 7 Nov 2016 12:13:33 -0500 Subject: [PATCH 2/5] Removed unused import --- nibabel/streamlines/tests/test_array_sequence.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/nibabel/streamlines/tests/test_array_sequence.py b/nibabel/streamlines/tests/test_array_sequence.py index 903fefe1a9..10990efb0b 100644 --- a/nibabel/streamlines/tests/test_array_sequence.py +++ b/nibabel/streamlines/tests/test_array_sequence.py @@ -8,7 +8,7 @@ from nibabel.testing import assert_arrays_equal from numpy.testing import assert_array_equal -from ..array_sequence import ArraySequence, is_array_sequence, concatenate, create_arraysequences_from_generator +from ..array_sequence import ArraySequence, is_array_sequence, concatenate SEQ_DATA = {} From 410740156b49a99c19f0a14fee229bdd0e47b32a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marc-Alexandre=20C=C3=B4t=C3=A9?= Date: Mon, 7 Nov 2016 12:33:40 -0500 Subject: [PATCH 3/5] PEP8 --- nibabel/streamlines/array_sequence.py | 1 - 1 file changed, 1 deletion(-) diff --git a/nibabel/streamlines/array_sequence.py b/nibabel/streamlines/array_sequence.py index 0312dabb69..dbfcb7a526 100644 --- a/nibabel/streamlines/array_sequence.py +++ b/nibabel/streamlines/array_sequence.py @@ -318,7 +318,6 @@ def __getitem__(self, idx): raise TypeError("Index must be either an int, a slice, a list of int" " or a ndarray of bool! Not " + str(type(idx))) - def __iter__(self): if len(self._lengths) != len(self._offsets): raise ValueError("ArraySequence object corrupted:" From a3aedc73c517aff13c3b64f375fd2c642e71e946 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marc-Alexandre=20C=C3=B4t=C3=A9?= Date: Thu, 19 Jan 2017 18:02:19 -0500 Subject: [PATCH 4/5] Addressed @matthew-brett's comments --- nibabel/streamlines/array_sequence.py | 2 +- nibabel/streamlines/tests/test_array_sequence.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/nibabel/streamlines/array_sequence.py b/nibabel/streamlines/array_sequence.py index dbfcb7a526..d892b9b91d 100644 --- a/nibabel/streamlines/array_sequence.py +++ b/nibabel/streamlines/array_sequence.py @@ -387,7 +387,7 @@ def concatenate(seqs, axis): Parameters ---------- - seqs: list of :class:`ArraySequence` objects + seqs: iterable of :class:`ArraySequence` objects Sequences to concatenate. axis : int Axis along which the sequences will be concatenated. diff --git a/nibabel/streamlines/tests/test_array_sequence.py b/nibabel/streamlines/tests/test_array_sequence.py index 10990efb0b..f41d50132a 100644 --- a/nibabel/streamlines/tests/test_array_sequence.py +++ b/nibabel/streamlines/tests/test_array_sequence.py @@ -305,12 +305,12 @@ def test_concatenate(): seq = SEQ_DATA['seq'].copy() # In case there is in-place modification. seqs = [seq[:, [i]] for i in range(seq.common_shape[0])] new_seq = concatenate(seqs, axis=1) + seq._data += 100 # Modifying the 'seq' shouldn't change 'new_seq'. check_arr_seq(new_seq, SEQ_DATA['data']) assert_true(not new_seq._is_view) - seq = SEQ_DATA['seq'].copy() # In case there is in-place modification. + seq = SEQ_DATA['seq'] seqs = [seq[:, [i]] for i in range(seq.common_shape[0])] new_seq = concatenate(seqs, axis=0) - assert_true(len(new_seq), seq.common_shape[0]*len(seq)) assert_array_equal(new_seq._data, seq._data.T.reshape((-1, 1))) From 09f5de5a1a8724e1e585fdc5583257abd50dba6d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marc-Alexandre=20C=C3=B4t=C3=A9?= Date: Tue, 14 Feb 2017 12:17:24 -0500 Subject: [PATCH 5/5] PEP8 --- nibabel/streamlines/tests/test_array_sequence.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/nibabel/streamlines/tests/test_array_sequence.py b/nibabel/streamlines/tests/test_array_sequence.py index f41d50132a..42bd6ba49a 100644 --- a/nibabel/streamlines/tests/test_array_sequence.py +++ b/nibabel/streamlines/tests/test_array_sequence.py @@ -312,5 +312,5 @@ def test_concatenate(): seq = SEQ_DATA['seq'] seqs = [seq[:, [i]] for i in range(seq.common_shape[0])] new_seq = concatenate(seqs, axis=0) - assert_true(len(new_seq), seq.common_shape[0]*len(seq)) + assert_true(len(new_seq), seq.common_shape[0] * len(seq)) assert_array_equal(new_seq._data, seq._data.T.reshape((-1, 1)))