@@ -9,6 +9,7 @@
 from torchtrain.datasets.tokenizer import TokenizerIf
 
 from datasets import load_dataset
+from datasets.distributed import split_dataset_by_node
 
 
 class AlpacaDataset(IterableDataset):
@@ -44,32 +45,24 @@ def __init__(
         rank: int = 0,
         **kwargs
     ) -> None:
-        self._data = load_dataset("tatsu-lab/alpaca", split="train")
+        # TODO: This is a temporary solution for small datasets like Alpaca.
+        # For larger datasets we need to use a more scalable approach.
+        # Setting `streaming=True` works for large dataset, but the speed is slow.
+        ds = load_dataset("tatsu-lab/alpaca", split="train")
+        self.data_iterator = iter(split_dataset_by_node(ds, rank, world_size))
         self._tokenizer = tokenizer
-        self.data_iterator = iter(self._data)
         self.seq_len = seq_len
-        self.world_size = world_size
-        self.rank = rank
-        self.response_tag = "\n\n### Response:\n"
-
-    def __len__(self):
-        return len(self._data)
 
     def __iter__(self):
         max_buffer_token_len = 1 + self.seq_len
         all_tokens: List[int] = []
 
-        for idx, sample in enumerate(self.data_iterator):
-            # select samples to pack in a round-robin fashion
-            # TODO: This is a temporary solution for small datasets like Alpaca.
-            # For larger datasets we need to use a more scalable approach.
-            if idx % self.world_size != self.rank:
-                continue
+        for sample in self.data_iterator:
            sample_text = sample["text"]
            sample_tokens = self._tokenizer.encode(sample_text, bos=True, eos=True)
            all_tokens.extend(sample_tokens)
 
-            if len(all_tokens) >= max_buffer_token_len:
+            while len(all_tokens) >= max_buffer_token_len:
                 x = torch.LongTensor(all_tokens[:max_buffer_token_len])
                 # batched_x = x.reshape(self.batch_size, -1)
                 # update tokens to the remaining tokens
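
Below is a minimal, standalone sketch (not part of the diff) of how `split_dataset_by_node` behaves for a map-style Hugging Face dataset: each rank gets a contiguous, disjoint shard, which is what replaces the old `idx % world_size != rank` round-robin filter. The toy dataset and `world_size` value are invented for illustration.

```python
from datasets import Dataset
from datasets.distributed import split_dataset_by_node

# Toy map-style dataset standing in for "tatsu-lab/alpaca".
ds = Dataset.from_dict({"text": [f"sample {i}" for i in range(10)]})

world_size = 2
for rank in range(world_size):
    shard = split_dataset_by_node(ds, rank=rank, world_size=world_size)
    # Each rank sees a disjoint contiguous slice (rank 0 -> samples 0-4,
    # rank 1 -> samples 5-9), so no per-sample modulo filtering is needed.
    print(rank, [row["text"] for row in shard])
```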
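
The `if` -> `while` change matters when a single tokenized sample is longer than one buffer's worth of tokens: that sample can fill several full sequences, and each of them should be emitted before the next sample is read. A standalone sketch of the packing logic under that change (the `pack_tokens` helper, sample lengths, and `seq_len` here are hypothetical, not from the PR):

```python
from typing import Iterator, List

def pack_tokens(samples: List[List[int]], seq_len: int) -> Iterator[List[int]]:
    # One extra token so a chunk can be split into shifted input/label pairs.
    max_buffer_token_len = 1 + seq_len
    all_tokens: List[int] = []
    for sample_tokens in samples:
        all_tokens.extend(sample_tokens)
        # `while`, not `if`: one long sample may span several full chunks,
        # and all of them should be drained before reading more data.
        while len(all_tokens) >= max_buffer_token_len:
            yield all_tokens[:max_buffer_token_len]
            all_tokens = all_tokens[max_buffer_token_len:]  # keep the remainder buffered

chunks = list(pack_tokens([[1] * 10, [2] * 3], seq_len=4))
print(len(chunks))  # 2 -- the 10-token sample alone yields two 5-token chunks
```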