
Commit 55a6b0b

modify data split to use HF api
ghstack-source-id: e23d5e0
Pull Request resolved: #65
1 parent: 4dbe0af

File tree

1 file changed: +8 additions, -15 deletions


torchtrain/datasets/alpaca.py

Lines changed: 8 additions & 15 deletions
@@ -9,6 +9,7 @@
 from torchtrain.datasets.tokenizer import TokenizerIf
 
 from datasets import load_dataset
+from datasets.distributed import split_dataset_by_node
 
 
 class AlpacaDataset(IterableDataset):
@@ -44,32 +45,24 @@ def __init__(
         rank: int = 0,
         **kwargs
     ) -> None:
-        self._data = load_dataset("tatsu-lab/alpaca", split="train")
+        # TODO: This is a temporary solution for small datasets like Alpaca.
+        # For larger datasets we need to use a more scalable approach.
+        # Setting `streaming=True` works for large dataset, but the speed is slow.
+        ds = load_dataset("tatsu-lab/alpaca", split="train")
+        self.data_iterator = iter(split_dataset_by_node(ds, rank, world_size))
         self._tokenizer = tokenizer
-        self.data_iterator = iter(self._data)
         self.seq_len = seq_len
-        self.world_size = world_size
-        self.rank = rank
-        self.response_tag = "\n\n### Response:\n"
-
-    def __len__(self):
-        return len(self._data)
 
     def __iter__(self):
         max_buffer_token_len = 1 + self.seq_len
         all_tokens: List[int] = []
 
-        for idx, sample in enumerate(self.data_iterator):
-            # select samples to pack in a round-robin fashion
-            # TODO: This is a temporary solution for small datasets like Alpaca.
-            # For larger datasets we need to use a more scalable approach.
-            if idx % self.world_size != self.rank:
-                continue
+        for sample in self.data_iterator:
             sample_text = sample["text"]
             sample_tokens = self._tokenizer.encode(sample_text, bos=True, eos=True)
             all_tokens.extend(sample_tokens)
 
-            if len(all_tokens) >= max_buffer_token_len:
+            while len(all_tokens) >= max_buffer_token_len:
                 x = torch.LongTensor(all_tokens[:max_buffer_token_len])
                 # batched_x = x.reshape(self.batch_size, -1)
                 # update tokens to the remaining tokens
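
For reference, a minimal sketch (not part of this commit) of what split_dataset_by_node from the datasets library does; the toy in-memory dataset below is a stand-in for load_dataset("tatsu-lab/alpaca", split="train"):

# Each rank receives its own shard of the dataset up front, so the old
# `idx % self.world_size != self.rank` filtering inside __iter__ is no longer needed.
from datasets import Dataset
from datasets.distributed import split_dataset_by_node

# Toy stand-in for load_dataset("tatsu-lab/alpaca", split="train").
ds = Dataset.from_dict({"text": [f"sample {i}" for i in range(10)]})

world_size = 2
for rank in range(world_size):
    shard = split_dataset_by_node(ds, rank=rank, world_size=world_size)
    print(rank, [row["text"] for row in shard])
# Each rank iterates over a disjoint subset of roughly len(ds) / world_size samples.

Separately, the packing check changes from if to while, so a single sample whose tokens exceed max_buffer_token_len can now yield multiple fixed-length sequences instead of carrying the surplus over to the next sample.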
