diff --git a/nibabel/streamlines/array_sequence.py b/nibabel/streamlines/array_sequence.py index d892b9b91d..91e94d7ac1 100644 --- a/nibabel/streamlines/array_sequence.py +++ b/nibabel/streamlines/array_sequence.py @@ -37,7 +37,7 @@ def __init__(self, arr_seq, common_shape, dtype): self.common_shape = common_shape n_in_row = reduce(mul, common_shape, 1) bytes_per_row = n_in_row * dtype.itemsize - self.rows_per_buf = bytes_per_row / self.bytes_per_buf + self.rows_per_buf = max(1, self.bytes_per_buf // bytes_per_row) def update_seq(self, arr_seq): arr_seq._offsets = np.array(self.offsets) @@ -185,6 +185,7 @@ def finalize_append(self): return self._build_cache.update_seq(self) self._build_cache = None + self.shrink_data() def _resize_data_to(self, n_rows, build_cache): """ Resize data array if required """ diff --git a/nibabel/streamlines/tests/test_array_sequence.py b/nibabel/streamlines/tests/test_array_sequence.py index 6925f58b35..45f50075f8 100644 --- a/nibabel/streamlines/tests/test_array_sequence.py +++ b/nibabel/streamlines/tests/test_array_sequence.py @@ -2,6 +2,7 @@ import sys import unittest import tempfile +import itertools import numpy as np from nose.tools import assert_equal, assert_raises, assert_true @@ -91,11 +92,20 @@ def test_creating_arraysequence_from_list(self): SEQ_DATA['data']) def test_creating_arraysequence_from_generator(self): - gen = (e for e in SEQ_DATA['data']) - check_arr_seq(ArraySequence(gen), SEQ_DATA['data']) + gen_1, gen_2 = itertools.tee((e for e in SEQ_DATA['data'])) + seq = ArraySequence(gen_1) + seq_with_buffer = ArraySequence(gen_2, buffer_size=256) + + # Check buffer size effect + assert_equal(seq_with_buffer.data.shape, seq.data.shape) + assert_true(seq_with_buffer._buffer_size > seq._buffer_size) + + # Check generator result + check_arr_seq(seq, SEQ_DATA['data']) + check_arr_seq(seq_with_buffer, SEQ_DATA['data']) # Already consumed generator - check_empty_arr_seq(ArraySequence(gen)) + check_empty_arr_seq(ArraySequence(gen_1)) def test_creating_arraysequence_from_arraysequence(self): seq = ArraySequence(SEQ_DATA['data'])