@@ -20,6 +20,7 @@ class AlpacaDataset(IterableDataset):
         seq_len (int): max sequence length
         world_size (int): number of data parallel processes participating in training
         rank (int): rank of the current data parallel process
+        infinite (bool): whether to loop infinitely over the dataset
 
     Data input format:
     {
@@ -43,38 +44,47 @@ def __init__(
         seq_len: int = 2048,
         world_size: int = 1,
         rank: int = 0,
+        infinite: bool = False,
         **kwargs
     ) -> None:
         # TODO: This is a temporary solution for small datasets like Alpaca.
         # For larger datasets, we need to use a more scalable approach.
         # Setting `streaming=True` works for large datasets, but it is slow.
         ds = load_dataset("tatsu-lab/alpaca", split="train")
-        self.data_iterator = iter(split_dataset_by_node(ds, rank, world_size))
+        self._data = split_dataset_by_node(ds, rank, world_size)
         self._tokenizer = tokenizer
         self.seq_len = seq_len
+        self.infinite = infinite
 
     def __iter__(self):
         max_buffer_token_len = 1 + self.seq_len
         all_tokens: List[int] = []
 
-        for sample in self.data_iterator:
-            sample_text = sample["text"]
-            sample_tokens = self._tokenizer.encode(sample_text, bos=True, eos=True)
-            all_tokens.extend(sample_tokens)
+        while True:
+            for sample in iter(self._data):
+                sample_text = sample["text"]
+                sample_tokens = self._tokenizer.encode(sample_text, bos=True, eos=True)
+                all_tokens.extend(sample_tokens)
 
-            while len(all_tokens) >= max_buffer_token_len:
-                x = torch.LongTensor(all_tokens[:max_buffer_token_len])
-                # batched_x = x.reshape(self.batch_size, -1)
-                # update tokens to the remaining tokens
-                all_tokens = all_tokens[max_buffer_token_len:]
-                input = x[:-1]
-                label = x[1:]
-                yield input, label
+                while len(all_tokens) >= max_buffer_token_len:
+                    x = torch.LongTensor(all_tokens[:max_buffer_token_len])
+                    # update tokens to the remaining tokens
+                    all_tokens = all_tokens[max_buffer_token_len:]
+                    input = x[:-1]
+                    label = x[1:]
+                    yield input, label
+            if not self.infinite:
+                break
 
 
 def build_alpaca_data_loader(
-    tokenizer: TokenizerIf, batch_size: int, seq_len: int, world_size, rank
+    tokenizer: TokenizerIf,
+    batch_size: int,
+    seq_len: int,
+    world_size: int,
+    rank: int,
+    infinite: bool = True,
 ):
-    alpaca_ds = AlpacaDataset(tokenizer, seq_len, world_size, rank)
+    alpaca_ds = AlpacaDataset(tokenizer, seq_len, world_size, rank, infinite)
 
     return DataLoader(alpaca_ds, batch_size=batch_size)
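
To make the behavior of the new flag concrete, below is a minimal, runnable sketch of the same packing-plus-wraparound pattern. ToyInfiniteDataset and its pre-tokenized integer samples are hypothetical stand-ins for AlpacaDataset and the Hugging Face shard (the real class also tokenizes text and splits the data by rank), but the __iter__ logic mirrors the diff above.

    # Minimal sketch of the packing + optional-wraparound pattern above.
    from typing import List

    import torch
    from torch.utils.data import DataLoader, IterableDataset

    class ToyInfiniteDataset(IterableDataset):
        def __init__(self, samples: List[List[int]], seq_len: int = 8, infinite: bool = False):
            self._data = samples      # pre-tokenized samples stand in for the HF shard
            self.seq_len = seq_len
            self.infinite = infinite

        def __iter__(self):
            max_buffer_token_len = 1 + self.seq_len
            all_tokens: List[int] = []
            while True:
                for sample_tokens in iter(self._data):
                    all_tokens.extend(sample_tokens)
                    # emit fixed-length chunks, shifted by one for next-token labels
                    while len(all_tokens) >= max_buffer_token_len:
                        x = torch.LongTensor(all_tokens[:max_buffer_token_len])
                        all_tokens = all_tokens[max_buffer_token_len:]
                        yield x[:-1], x[1:]
                if not self.infinite:
                    break  # finite mode: stop after a single pass

    samples = [list(range(20)) for _ in range(3)]  # 3 "documents", 60 tokens total

    finite = DataLoader(ToyInfiniteDataset(samples, infinite=False), batch_size=2)
    print(sum(1 for _ in finite))         # 3 batches, then the loader is exhausted

    endless = DataLoader(ToyInfiniteDataset(samples, infinite=True), batch_size=2)
    inputs, labels = next(iter(endless))  # next() can be called indefinitely
    print(inputs.shape, labels.shape)     # torch.Size([2, 8]) torch.Size([2, 8])

Note the asymmetry in defaults: AlpacaDataset itself defaults to infinite=False, while build_alpaca_data_loader defaults to infinite=True, presumably so a long training run does not exhaust the DataLoader partway through on a dataset as small as Alpaca.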