This repository was archived by the owner on Sep 23, 2025. It is now read-only.

Commit 05fdf80: refactor

Signed-off-by: minmingzhu <[email protected]>
1 parent c506398

File tree: 2 files changed (+5, -401 lines)


llm_on_ray/finetune/dpo_funetuing.py renamed to llm_on_ray/finetune/data_preprocess.py

Lines changed: 5 additions & 90 deletions
@@ -14,23 +14,15 @@
 # limitations under the License.
 #
 import datasets
-import torch
-from peft import LoraConfig
-from transformers import AutoModelForCausalLM
 from typing import Dict
 
 IGNORE_INDEX = -100
 
 
-class DPOIntelOrcaProcesser:
+class DPOIntelOrcaPreprocesser:
     @staticmethod
     def tokenize_dataset(config, tokenizer, dataset):
         tokenizer.pad_token = tokenizer.eos_token
-        if isinstance(dataset, datasets.Dataset):
-            column_names = dataset.column_names
-
-        if isinstance(dataset, datasets.DatasetDict):
-            column_names = dataset["train"].column_names
 
         def return_prompt_and_responses(samples) -> Dict[str, str]:
             return {
@@ -44,15 +36,11 @@ def return_prompt_and_responses(samples) -> Dict[str, str]:
                 "rejected": samples["rejected"],
             }
 
-        raw_datasets = dataset.map(
+        dataset = dataset.map(
             return_prompt_and_responses,
-            remove_columns=column_names,
             load_from_cache_file=False,
             desc="Tokenize dataset",
         )
-        train_dataset = raw_datasets["train"]
-        column_names = train_dataset.column_names
-
         """
         Copied from https://github.com/intel/intel-extension-for-transformers/blob/5ba5fa8048b63bec8a3be8a7122a3db8344ad065/
         intel_extension_for_transformers/neural_chat/examples/finetuning/dpo_pipeline/dpo_clm.py#L308
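
Note on this hunk: with remove_columns dropped from the map() call, the mapped
DatasetDict keeps its original columns next to the newly created ones, which is
why the refactor derives and removes them later instead. A toy check of that
behavior follows; the Orca-style column names are an assumption for
illustration, not code from this repository:

from datasets import Dataset, DatasetDict

# Illustrative stand-in for an Intel Orca DPO pairs schema (assumed columns).
ds = DatasetDict(
    {
        "train": Dataset.from_dict(
            {
                "system": ["You are a helpful assistant."],
                "question": ["What is 2 + 2?"],
                "chosen": ["4"],
                "rejected": ["5"],
            }
        )
    }
)

# Without remove_columns, the original columns survive alongside "prompt".
mapped = ds.map(lambda sample: {"prompt": sample["system"] + " " + sample["question"]})
print(mapped["train"].column_names)  # original columns plus "prompt"
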
@@ -145,6 +133,8 @@ def preprocess_function(examples):
 
             return examples
 
+        train_dataset = dataset["train"]
+        column_names = list(train_dataset.features)
         if train_dataset is not None:
             # Create train feature from dataset
             train_dataset = train_dataset.map(
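
Note on this hunk: column_names is now computed from the already-mapped split.
For a datasets.Dataset, the schema is a Features mapping whose keys are the
column names, so list(train_dataset.features) matches column_names. A
standalone sketch (toy data, not repository code):

from datasets import Dataset

ds = Dataset.from_dict({"prompt": ["hi"], "chosen": ["a"], "rejected": ["b"]})
# Features is a dict-like schema; listing it yields the column names in order.
assert list(ds.features) == ds.column_names == ["prompt", "chosen", "rejected"]
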
@@ -154,7 +144,7 @@ def preprocess_function(examples):
             desc="Running tokenizer on train dataset",
         )
 
-        eval_dataset = raw_datasets.get("validation")
+        eval_dataset = dataset.get("validation")
 
         if eval_dataset is not None:
             column_names = eval_dataset.column_names
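
Note on this hunk: switching from raw_datasets.get("validation") to
dataset.get("validation") works because DatasetDict subclasses dict, so .get()
returns None for a missing split instead of raising. A toy check (the split
contents are illustrative only):

from datasets import Dataset, DatasetDict

ds = DatasetDict({"train": Dataset.from_dict({"x": [1]})})
# No "validation" split was defined, so .get() falls back to None.
assert ds.get("validation") is None
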
@@ -167,78 +157,3 @@ def preprocess_function(examples):
         tokenized_datasets = {"train": train_dataset, "validation": eval_dataset}
 
         return tokenized_datasets
-
-
-class DPOFuneTuning:
-    def __init__(self, config):
-        self.config = config
-        self.torch_dtype = (
-            self.config["Dataset"]["torch_dtype"]
-            if self.config["Dataset"]["torch_dtype"] in ["auto", None]
-            else getattr(torch, self.config["Dataset"]["torch_dtype"])
-        )
-
-    def get_model(self):
-        # load policy model
-        model = AutoModelForCausalLM.from_pretrained(
-            self.config["General"]["base_model"],
-            config=self.config,
-            low_cpu_mem_usage=True,
-            torch_dtype=self.torch_dtype,
-            use_auth_token=True if self.config["General"]["config"]["use_auth_token"] else None,
-        )
-        model.config.use_cache = False
-        return model
-
-    def get_model_ref(self):
-        # load reference model
-        model_ref = AutoModelForCausalLM.from_pretrained(
-            self.config["General"]["base_model"],
-            config=self.config,
-            low_cpu_mem_usage=True,
-            torch_dtype=self.torch_dtype,
-            use_auth_token=True if self.config["General"]["config"]["use_auth_token"] else None,
-        )
-        model_ref.config.use_cache = False
-        return model_ref
-
-    def dpo_train(self, training_args, tokenized_datasets, tokenizer):
-        from trl import DPOTrainer
-
-        lora_config = self.config["General"].get("lora_config", None)
-        return DPOTrainer(
-            self.get_model(),
-            self.get_model_ref() if lora_config is not None else None,
-            args=training_args,
-            beta=self.config["Training"].get("beta"),
-            train_dataset=tokenized_datasets["train"],
-            eval_dataset=tokenized_datasets["validation"]
-            if tokenized_datasets.get("validation") is not None
-            else None,
-            tokenizer=tokenizer,
-            peft_config=LoraConfig(**lora_config) if lora_config is not None else None,
-            max_length=self.config["Dataset"].get("max_length"),
-            max_prompt_length=self.config["Dataset"].get("max_prompt_length"),
-        )
-
-
-class GaudiDPOFuneTuning(DPOFuneTuning):
-    def dpo_train(self, training_args, gaudi_config, tokenized_datasets, tokenizer):
-        from optimum.habana.trl import GaudiDPOTrainer as DPOTrainer
-
-        lora_config = self.config["General"].get("lora_config", None)
-        return DPOTrainer(
-            self.get_model(),
-            self.get_model_ref() if lora_config is not None else None,
-            args=training_args,
-            gaudi_config=gaudi_config,
-            beta=self.config["Training"].get("beta"),
-            train_dataset=tokenized_datasets["train"],
-            eval_dataset=tokenized_datasets["validation"]
-            if tokenized_datasets.get("validation") is not None
-            else None,
-            tokenizer=tokenizer,
-            peft_config=LoraConfig(**lora_config) if lora_config is not None else None,
-            max_length=self.config["Dataset"].get("max_length"),
-            max_prompt_length=self.config["Dataset"].get("max_prompt_length"),
-        )
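
After this commit, data_preprocess.py keeps only the preprocessing class; the
DPOFuneTuning and GaudiDPOFuneTuning trainer wrappers are removed from this
file. A minimal usage sketch of the renamed preprocessor follows. Only the
class name and the tokenize_dataset(config, tokenizer, dataset) signature come
from this diff; the dataset id, tokenizer, and config shape are assumptions for
illustration:

from datasets import load_dataset
from transformers import AutoTokenizer

from llm_on_ray.finetune.data_preprocess import DPOIntelOrcaPreprocesser

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # placeholder tokenizer
dataset = load_dataset("Intel/orca_dpo_pairs")  # assumed source dataset
config = {"Dataset": {"max_length": 1024, "max_prompt_length": 512}}  # assumed keys

# tokenize_dataset maps the raw pairs to prompt/chosen/rejected and tokenizes.
tokenized = DPOIntelOrcaPreprocesser.tokenize_dataset(config, tokenizer, dataset)
print(tokenized["train"])
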
