@@ -14,23 +14,15 @@
 # limitations under the License.
 #
 import datasets
-import torch
-from peft import LoraConfig
-from transformers import AutoModelForCausalLM
 from typing import Dict
 
 IGNORE_INDEX = -100
 
 
-class DPOIntelOrcaProcesser:
+class DPOIntelOrcaPreprocesser:
     @staticmethod
     def tokenize_dataset(config, tokenizer, dataset):
         tokenizer.pad_token = tokenizer.eos_token
-        if isinstance(dataset, datasets.Dataset):
-            column_names = dataset.column_names
-
-        if isinstance(dataset, datasets.DatasetDict):
-            column_names = dataset["train"].column_names
 
         def return_prompt_and_responses(samples) -> Dict[str, str]:
             return {
@@ -44,15 +36,11 @@ def return_prompt_and_responses(samples) -> Dict[str, str]:
4436 "rejected" : samples ["rejected" ],
4537 }
4638
47- raw_datasets = dataset .map (
39+ dataset = dataset .map (
4840 return_prompt_and_responses ,
49- remove_columns = column_names ,
5041 load_from_cache_file = False ,
5142 desc = "Tokenize dataset" ,
5243 )
53- train_dataset = raw_datasets ["train" ]
54- column_names = train_dataset .column_names
55-
5644 """
5745 Copied from https://github.com/intel/intel-extension-for-transformers/blob/5ba5fa8048b63bec8a3be8a7122a3db8344ad065/
5846 intel_extension_for_transformers/neural_chat/examples/finetuning/dpo_pipeline/dpo_clm.py#L308
@@ -145,6 +133,8 @@ def preprocess_function(examples):
 
             return examples
 
+        train_dataset = dataset["train"]
+        column_names = list(train_dataset.features)
         if train_dataset is not None:
             # Create train feature from dataset
             train_dataset = train_dataset.map(
@@ -154,7 +144,7 @@ def preprocess_function(examples):
                 desc="Running tokenizer on train dataset",
             )
 
-        eval_dataset = raw_datasets.get("validation")
+        eval_dataset = dataset.get("validation")
 
         if eval_dataset is not None:
             column_names = eval_dataset.column_names
@@ -167,78 +157,3 @@ def preprocess_function(examples):
         tokenized_datasets = {"train": train_dataset, "validation": eval_dataset}
 
         return tokenized_datasets
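
For reference, a minimal sketch of what the refactored `tokenize_dataset` flow now does: the raw columns survive the first `map` (no more `remove_columns`), and `column_names` is resolved from `train_dataset.features` only after the split is selected. The prompt formatting and the tiny in-memory dataset below are illustrative assumptions, not the elided `return_prompt_and_responses` body.

```python
# Illustrative sketch only; assumes an Intel/orca_dpo_pairs-style DatasetDict
# with "system"/"question"/"chosen"/"rejected" columns. The prompt join below
# is a placeholder for the elided return_prompt_and_responses body.
import datasets

raw = datasets.DatasetDict({
    "train": datasets.Dataset.from_dict({
        "system": ["You are a helpful assistant."],
        "question": ["What is DPO?"],
        "chosen": ["Direct Preference Optimization fine-tunes on preference pairs."],
        "rejected": ["I don't know."],
    })
})

def return_prompt_and_responses(samples):
    return {
        "prompt": samples["system"] + " " + samples["question"],  # placeholder format
        "chosen": samples["chosen"],
        "rejected": samples["rejected"],
    }

# No remove_columns here: the original columns are kept until preprocessing.
raw = raw.map(return_prompt_and_responses, load_from_cache_file=False)
train_dataset = raw["train"]
column_names = list(train_dataset.features)  # resolved after selecting the split
print(column_names)  # ['system', 'question', 'chosen', 'rejected', 'prompt']
```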
-
-
-class DPOFuneTuning:
-    def __init__(self, config):
-        self.config = config
-        self.torch_dtype = (
-            self.config["Dataset"]["torch_dtype"]
-            if self.config["Dataset"]["torch_dtype"] in ["auto", None]
-            else getattr(torch, self.config["Dataset"]["torch_dtype"])
-        )
-
-    def get_model(self):
-        # load policy model
-        model = AutoModelForCausalLM.from_pretrained(
-            self.config["General"]["base_model"],
-            config=self.config,
-            low_cpu_mem_usage=True,
-            torch_dtype=self.torch_dtype,
-            use_auth_token=True if self.config["General"]["config"]["use_auth_token"] else None,
-        )
-        model.config.use_cache = False
-        return model
-
-    def get_model_ref(self):
-        # load reference model
-        model_ref = AutoModelForCausalLM.from_pretrained(
-            self.config["General"]["base_model"],
-            config=self.config,
-            low_cpu_mem_usage=True,
-            torch_dtype=self.torch_dtype,
-            use_auth_token=True if self.config["General"]["config"]["use_auth_token"] else None,
-        )
-        model_ref.config.use_cache = False
-        return model_ref
-
-    def dpo_train(self, training_args, tokenized_datasets, tokenizer):
-        from trl import DPOTrainer
-
-        lora_config = self.config["General"].get("lora_config", None)
-        return DPOTrainer(
-            self.get_model(),
-            self.get_model_ref() if lora_config is not None else None,
-            args=training_args,
-            beta=self.config["Training"].get("beta"),
-            train_dataset=tokenized_datasets["train"],
-            eval_dataset=tokenized_datasets["validation"]
-            if tokenized_datasets.get("validation") is not None
-            else None,
-            tokenizer=tokenizer,
-            peft_config=LoraConfig(**lora_config) if lora_config is not None else None,
-            max_length=self.config["Dataset"].get("max_length"),
-            max_prompt_length=self.config["Dataset"].get("max_prompt_length"),
-        )
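
The removed `DPOFuneTuning` class wired `trl`'s `DPOTrainer` directly. As a standalone reference, here is a hedged sketch of that wiring against the trl 0.7-era API this file appears to target; the model name, hyperparameters, and the tiny dataset are placeholders. Note that when a `peft_config` is supplied, trl can derive the reference policy by disabling the adapters, which is why `None` is commonly passed as the second argument.

```python
# Sketch under stated assumptions (trl 0.7-era DPOTrainer signature); not the
# project's own entry point. All names and values below are placeholders.
import datasets
import torch
from peft import LoraConfig
from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments
from trl import DPOTrainer

model = AutoModelForCausalLM.from_pretrained("gpt2", torch_dtype=torch.float32)
model.config.use_cache = False  # mirrors the removed get_model()
tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token

train_ds = datasets.Dataset.from_dict({
    "prompt": ["What is DPO?"],
    "chosen": ["A method that fine-tunes directly on preference pairs."],
    "rejected": ["No idea."],
})

trainer = DPOTrainer(
    model,
    None,  # ref model: with peft_config set, trl recovers it by disabling adapters
    args=TrainingArguments(
        output_dir="/tmp/dpo_sketch",      # placeholder
        per_device_train_batch_size=1,
        remove_unused_columns=False,       # keep prompt/chosen/rejected for the trainer
    ),
    beta=0.1,
    train_dataset=train_ds,
    tokenizer=tokenizer,
    peft_config=LoraConfig(r=8, lora_alpha=16, task_type="CAUSAL_LM"),
    max_length=512,
    max_prompt_length=256,
)
trainer.train()
```

Against newer trl releases the signature differs (`beta` and the length limits moved into `DPOConfig`), so treat this sketch as tied to the pinned version.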
-
-
-class GaudiDPOFuneTuning(DPOFuneTuning):
-    def dpo_train(self, training_args, gaudi_config, tokenized_datasets, tokenizer):
-        from optimum.habana.trl import GaudiDPOTrainer as DPOTrainer
-
-        lora_config = self.config["General"].get("lora_config", None)
-        return DPOTrainer(
-            self.get_model(),
-            self.get_model_ref() if lora_config is not None else None,
-            args=training_args,
-            gaudi_config=gaudi_config,
-            beta=self.config["Training"].get("beta"),
-            train_dataset=tokenized_datasets["train"],
-            eval_dataset=tokenized_datasets["validation"]
-            if tokenized_datasets.get("validation") is not None
-            else None,
-            tokenizer=tokenizer,
-            peft_config=LoraConfig(**lora_config) if lora_config is not None else None,
-            max_length=self.config["Dataset"].get("max_length"),
-            max_prompt_length=self.config["Dataset"].get("max_prompt_length"),
-        )
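
The removed Gaudi subclass changed only two things relative to the base class: it imported `GaudiDPOTrainer` from `optimum.habana.trl` and threaded a `gaudi_config` through. A minimal sketch of that delta, assuming optimum-habana's documented `GaudiConfig`/`GaudiTrainingArguments` (values are placeholders):

```python
# Gaudi-specific delta only; assumes optimum-habana's public API. The trainer
# call is otherwise identical to the trl sketch above, plus gaudi_config.
from optimum.habana import GaudiConfig, GaudiTrainingArguments
from optimum.habana.trl import GaudiDPOTrainer as DPOTrainer

gaudi_config = GaudiConfig(use_fused_adam=True, use_fused_clip_norm=True)
training_args = GaudiTrainingArguments(
    output_dir="/tmp/dpo_gaudi_sketch",  # placeholder
    use_habana=True,
    use_lazy_mode=True,
)
# Constructed like the trl sketch above, with the extra keyword threaded in:
# trainer = DPOTrainer(model, None, args=training_args, gaudi_config=gaudi_config,
#                      beta=0.1, train_dataset=train_ds, tokenizer=tokenizer)
```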