Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
81 changes: 76 additions & 5 deletions src/datasets/iterable_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -683,7 +683,6 @@ def __init__(
# if undersampling ("first_exhausted"), we stop as soon as one dataset is exhausted
# if oversampling ("all_exhausted"), we stop as soon as every dataset is exhausted, i.e. as soon as every sample of every dataset has been visited at least once
self.bool_strategy_func = np.all if (stopping_strategy == "all_exhausted") else np.any
# TODO(QL): implement iter_arrow

@property
def is_typed(self):
Expand All @@ -693,6 +692,11 @@ def is_typed(self):
def features(self):
    # Expose the features of the first source; presumably all sources share
    # the same features — TODO confirm against the callers that build this iterable.
    return self.ex_iterables[0].features

@property
def iter_arrow(self):
    """Arrow-level iteration entry point, or None when unsupported.

    Arrow iteration over the combined sources is only possible when every
    wrapped ex_iterable exposes its own `iter_arrow`.
    """
    for ex_iterable in self.ex_iterables:
        if not ex_iterable.iter_arrow:
            return None
    return self._iter_arrow

def _get_indices_iterator(self):
# this is an infinite iterator to keep track of which iterator we want to pick examples from
ex_iterable_idx = self._state_dict["ex_iterable_idx"] if self._state_dict else 0
Expand All @@ -712,6 +716,48 @@ def _init_state_dict(self) -> dict:
}
return self._state_dict

def _iter_arrow(self):
    """Yield items by cycling over the sources' Arrow-level iterators.

    Mirrors the example-level `__iter__`: one item is prefetched from each
    source so that exhaustion can be detected, and `self.bool_strategy_func`
    (np.any for "first_exhausted", np.all for "all_exhausted") decides when
    to stop overall. Exhausted sources are restarted while the stopping
    criteria isn't met yet.
    """
    # we use this to buffer one example of each iterator to know if an iterator is exhausted
    nexts = [None] * len(self.ex_iterables)
    # because of that, we need to rewind 1 example when reloading the state dict
    if self._state_dict:
        for i in range(len(self.ex_iterables)):
            if self._state_dict["previous_states"][i] is not None:
                self.ex_iterables[i].load_state_dict(self._state_dict["previous_states"][i])
    iterators = [ex_iterable.iter_arrow() for ex_iterable in self.ex_iterables]

    indices_iterator = self._get_indices_iterator()

    # per-source exhaustion flags, restored from the state dict when resuming
    is_exhausted = (
        np.array(self._state_dict["is_exhausted"]) if self._state_dict else np.full(len(self.ex_iterables), False)
    )
    for i in indices_iterator:
        # if the stopping criteria is met, break the main for loop
        if self.bool_strategy_func(is_exhausted):
            break
        # let's pick one example from the iterator at index i
        if nexts[i] is None:
            nexts[i] = next(iterators[i], False)  # False is the exhaustion sentinel
        result = nexts[i]
        if self._state_dict:
            # snapshot the source state *before* prefetching the next item,
            # so that reloading the state dict rewinds by exactly one item
            self._state_dict["previous_states"][i] = deepcopy(self._state_dict["ex_iterables"][i])
        nexts[i] = next(iterators[i], False)

        # the iterator is exhausted
        if nexts[i] is False:
            is_exhausted[i] = True
            if self._state_dict:
                self._state_dict["is_exhausted"][i] = True
            # we reset it in case the stopping criteria isn't met yet
            nexts[i] = None
            if self._state_dict:
                self._state_dict["ex_iterables"][i] = self.ex_iterables[i]._init_state_dict()
                self._state_dict["previous_states"][i] = None
            iterators[i] = self.ex_iterables[i].iter_arrow()

        if result is not False:
            yield result

def __iter__(self):
# we use this to buffer one example of each iterator to know if an iterator is exhausted
nexts = [None] * len(self.ex_iterables)
Expand Down Expand Up @@ -1524,7 +1570,6 @@ def __init__(self, ex_iterable: _BaseExamplesIterable, buffer_size: int, generat
self.ex_iterable = ex_iterable
self.buffer_size = buffer_size
self.generator = generator
# TODO(QL): implement iter_arrow

@property
def is_typed(self):
Expand All @@ -1534,6 +1579,10 @@ def is_typed(self):
def features(self):
    # Shuffling does not alter the schema: delegate to the wrapped iterable.
    return self.ex_iterable.features

@property
def iter_arrow(self):
    """Arrow-level iteration entry point, or None when the wrapped iterable lacks one."""
    if not self.ex_iterable.iter_arrow:
        return None
    return self._iter_arrow

def _init_state_dict(self) -> dict:
self._state_dict = self.ex_iterable._init_state_dict()
self._original_state_dict = self.state_dict()
Expand Down Expand Up @@ -1570,6 +1619,23 @@ def __iter__(self):
rng.shuffle(mem_buffer)
yield from mem_buffer

def _iter_arrow(self):
buffer_size = self.buffer_size
rng = deepcopy(self.generator)
indices_iterator = self._iter_random_indices(rng, buffer_size)
# this is the shuffle buffer that we keep in memory
mem_buffer = []
for key, pa_table in self.ex_iterable.iter_arrow():
if len(mem_buffer) == buffer_size: # if the buffer is full, pick and example from it
i = next(indices_iterator)
yield mem_buffer[i]
mem_buffer[i] = (key, pa_table) # replace the picked example by a new one
else: # otherwise, keep filling the buffer
mem_buffer.append((key, pa_table))
# when we run out of examples, we shuffle the remaining examples in the buffer and yield them
rng.shuffle(mem_buffer)
yield from mem_buffer

def shuffle_data_sources(self, generator: np.random.Generator) -> "BufferShuffledExamplesIterable":
"""Shuffle the wrapped examples iterable as well as the shuffling buffer."""
return BufferShuffledExamplesIterable(
Expand Down Expand Up @@ -2870,8 +2936,12 @@ def shuffle(
generator = deepcopy(generator)
shuffling = ShufflingConfig(generator=generator, _original_seed=seed)
return IterableDataset(
ex_iterable=BufferShuffledExamplesIterable(
self._ex_iterable, buffer_size=buffer_size, generator=generator
BufferShuffledExamplesIterable(
RebatchedArrowExamplesIterable(self._ex_iterable, batch_size=1)
if self._ex_iterable.iter_arrow
else self._ex_iterable,
buffer_size=buffer_size,
generator=generator,
),
info=self._info.copy(),
split=self._split,
Expand Down Expand Up @@ -4458,7 +4528,8 @@ def _interleave_iterable_datasets(
)

ex_iterables = [copy.deepcopy(d._ex_iterable) for d in datasets]

if all(ex_iterable.iter_arrow for ex_iterable in ex_iterables):
ex_iterables = [RebatchedArrowExamplesIterable(ex_iterable, batch_size=1) for ex_iterable in ex_iterables]
# Use cycling or random cycling of sources
if probabilities is None:
ex_iterable = CyclingMultiSourcesExamplesIterable(ex_iterables, stopping_strategy=stopping_strategy)
Expand Down
Loading