
Commit bd6fe55

support infinite loop over alpaca dataset
ghstack-source-id: 38cbc27
Pull Request resolved: #92
1 parent ae85e97 commit bd6fe55
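
The behavioral change is easiest to see in isolation. Below is a minimal, self-contained sketch, not the repository's AlpacaDataset: a toy IterableDataset (hypothetical name ToyInfiniteDataset, toy integer samples) following the same pattern this commit introduces, where infinite=True restarts iteration over the underlying data instead of stopping after one pass.

    from torch.utils.data import IterableDataset

    class ToyInfiniteDataset(IterableDataset):
        """Minimal sketch of the infinite-iteration pattern; not the repo's AlpacaDataset."""

        def __init__(self, samples, infinite: bool = False):
            self._samples = samples
            self.infinite = infinite

        def __iter__(self):
            while True:
                for sample in iter(self._samples):
                    yield sample
                # stop after a single pass unless an endless stream was requested
                if not self.infinite:
                    break

    # One pass vs. an endless stream:
    finite = iter(ToyInfiniteDataset([1, 2, 3], infinite=False))
    assert list(finite) == [1, 2, 3]  # exhausts after one epoch

    endless = iter(ToyInfiniteDataset([1, 2, 3], infinite=True))
    assert [next(endless) for _ in range(7)] == [1, 2, 3, 1, 2, 3, 1]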

File tree

2 files changed: 28 additions & 16 deletions

torchtrain/datasets/alpaca.py

Lines changed: 25 additions & 15 deletions
@@ -20,6 +20,7 @@ class AlpacaDataset(IterableDataset):
         seq_len (int): max sequence length
         world_size (int): number of data parallel processes participating in training
         rank (int): rank of the current data parallel process
+        infinite: whether to loop infinitely over the dataset

     Data input format:
     {
@@ -43,38 +44,47 @@ def __init__(
         seq_len: int = 2048,
         world_size: int = 1,
         rank: int = 0,
+        infinite: bool = False,
         **kwargs
     ) -> None:
         # TODO: This is a temporary solution for small datasets like Alpaca.
         # For larger datasets we need to use a more scalable approach.
         # Setting `streaming=True` works for large dataset, but the speed is slow.
         ds = load_dataset("tatsu-lab/alpaca", split="train")
-        self.data_iterator = iter(split_dataset_by_node(ds, rank, world_size))
+        self._data = split_dataset_by_node(ds, rank, world_size)
         self._tokenizer = tokenizer
         self.seq_len = seq_len
+        self.infinite = infinite

     def __iter__(self):
         max_buffer_token_len = 1 + self.seq_len
         all_tokens: List[int] = []

-        for sample in self.data_iterator:
-            sample_text = sample["text"]
-            sample_tokens = self._tokenizer.encode(sample_text, bos=True, eos=True)
-            all_tokens.extend(sample_tokens)
+        while True:
+            for sample in iter(self._data):
+                sample_text = sample["text"]
+                sample_tokens = self._tokenizer.encode(sample_text, bos=True, eos=True)
+                all_tokens.extend(sample_tokens)

-            while len(all_tokens) >= max_buffer_token_len:
-                x = torch.LongTensor(all_tokens[:max_buffer_token_len])
-                # batched_x = x.reshape(self.batch_size, -1)
-                # update tokens to the remaining tokens
-                all_tokens = all_tokens[max_buffer_token_len:]
-                input = x[:-1]
-                label = x[1:]
-                yield input, label
+                while len(all_tokens) >= max_buffer_token_len:
+                    x = torch.LongTensor(all_tokens[:max_buffer_token_len])
+                    # update tokens to the remaining tokens
+                    all_tokens = all_tokens[max_buffer_token_len:]
+                    input = x[:-1]
+                    label = x[1:]
+                    yield input, label
+            if not self.infinite:
+                break


 def build_alpaca_data_loader(
-    tokenizer: TokenizerIf, batch_size: int, seq_len: int, world_size, rank
+    tokenizer: TokenizerIf,
+    batch_size: int,
+    seq_len: int,
+    world_size: int,
+    rank: int,
+    infinite: bool = True,
 ):
-    alpaca_ds = AlpacaDataset(tokenizer, seq_len, world_size, rank)
+    alpaca_ds = AlpacaDataset(tokenizer, seq_len, world_size, rank, infinite)

     return DataLoader(alpaca_ds, batch_size=batch_size)
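
In __iter__, samples are packed rather than padded: token ids from consecutive samples accumulate in one buffer, and whenever the buffer holds seq_len + 1 tokens they are cut into an input and a label shifted by one position. A self-contained sketch of just that packing step, using plain integers in place of tokenizer output; the helper name pack_tokens is hypothetical and not part of the repository.

    from typing import Iterable, Iterator, List, Tuple

    import torch

    def pack_tokens(
        token_streams: Iterable[List[int]], seq_len: int
    ) -> Iterator[Tuple[torch.Tensor, torch.Tensor]]:
        """Hypothetical helper mirroring the buffer logic in AlpacaDataset.__iter__."""
        max_buffer_token_len = 1 + seq_len
        all_tokens: List[int] = []
        for sample_tokens in token_streams:
            all_tokens.extend(sample_tokens)
            while len(all_tokens) >= max_buffer_token_len:
                x = torch.LongTensor(all_tokens[:max_buffer_token_len])
                all_tokens = all_tokens[max_buffer_token_len:]  # keep the remainder
                yield x[:-1], x[1:]  # label is the input shifted left by one token

    # Two "tokenized samples" of different lengths packed into seq_len=4 chunks.
    pairs = list(pack_tokens([[1, 2, 3], [4, 5, 6, 7, 8, 9, 10]], seq_len=4))
    assert [inp.tolist() for inp, _ in pairs] == [[1, 2, 3, 4], [6, 7, 8, 9]]
    assert [lab.tolist() for _, lab in pairs] == [[2, 3, 4, 5], [7, 8, 9, 10]]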

train.py

Lines changed: 3 additions & 1 deletion
@@ -167,6 +167,8 @@ def main(job_config: JobConfig):
     )
     checkpoint.load()

+    data_iterator = iter(data_loader)
+
     with maybe_run_profiler(job_config) as torch_profiler:
         checkpoint.reset()
         # variables used to keep info for metrics logging
@@ -180,7 +182,7 @@ def main(job_config: JobConfig):
             train_state.step += 1
             # get batch
             data_load_start = timer()
-            batch = next(iter(data_loader))
+            batch = next(data_iterator)
             input_ids, labels = batch
             input_ids = input_ids.cuda()
             labels = labels.cuda()
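
The train.py change matters because, with the dataset now re-creating its sample iterator inside __iter__, calling next(iter(data_loader)) every step would rebuild the DataLoader iterator and restart from the beginning; creating data_iterator once outside the loop advances through batches instead. A toy illustration of that general DataLoader behavior, not the repository's loader:

    from torch.utils.data import DataLoader

    # Toy stand-in: a map-style dataset of the integers 0..11, batched 4 at a time.
    loader = DataLoader(list(range(12)), batch_size=4)

    # Re-creating the iterator every step restarts iteration, so the same first
    # batch comes back each time.
    restarted = [next(iter(loader)).tolist() for _ in range(3)]
    assert restarted == [[0, 1, 2, 3]] * 3

    # Creating the iterator once and reusing it advances through the data,
    # which is what train.py now does with `data_iterator = iter(data_loader)`.
    data_iterator = iter(loader)
    advanced = [next(data_iterator).tolist() for _ in range(3)]
    assert advanced == [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]]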
