 #
 # -----------------------------------------------------------------------------

+import logging
 import random
 import warnings
 from typing import Any, Dict, Optional, Union
 import torch.utils.data
 from peft import PeftModel, get_peft_model
 from torch.optim.lr_scheduler import StepLR
-from transformers import AutoModel, AutoModelForCausalLM, AutoTokenizer
+from transformers import AutoModel, AutoModelForCausalLM, AutoModelForSequenceClassification, AutoTokenizer

 from QEfficient.finetune.configs.training import TrainConfig
 from QEfficient.finetune.utils.config_utils import (
     generate_dataset_config,
     generate_peft_config,
     update_config,
 )
 from QEfficient.finetune.utils.dataset_utils import get_dataloader
+from QEfficient.finetune.utils.logging_utils import logger
 from QEfficient.finetune.utils.parser import get_finetune_parser
-from QEfficient.finetune.utils.train_utils import get_longest_seq_length, print_model_size, train
-from QEfficient.utils._utils import login_and_download_hf_lm
+from QEfficient.finetune.utils.train_utils import (
+    get_longest_seq_length,
+    print_model_size,
+    print_trainable_parameters,
+    train,
+)
+from QEfficient.utils._utils import hf_download

 # Try importing QAIC-specific module, proceed without it if unavailable
 try:
     import torch_qaic  # noqa: F401
 except ImportError as e:
-    print(f"Warning: {e}. Proceeding without QAIC modules.")
-
+    logger.log_rank_zero(f"{e}. Moving ahead without these qaic modules.", logging.WARNING)

-from transformers import AutoModelForSequenceClassification

 # Suppress all warnings
 warnings.filterwarnings("ignore")
@@ -106,7 +111,8 @@ def load_model_and_tokenizer(
         - Resizes model embeddings if tokenizer vocab size exceeds model embedding size.
         - Sets pad_token_id to eos_token_id if not defined in the tokenizer.
     """
-    pretrained_model_path = login_and_download_hf_lm(train_config.model_name)
+    logger.log_rank_zero(f"Loading HuggingFace model for {train_config.model_name}")
+    pretrained_model_path = hf_download(train_config.model_name)
     if train_config.task_type == "seq_classification":
         model = AutoModelForSequenceClassification.from_pretrained(
             pretrained_model_path,
@@ -116,7 +122,7 @@ def load_model_and_tokenizer(
         )

         if not hasattr(model, "base_model_prefix"):
-            raise RuntimeError("Given huggingface model does not have 'base_model_prefix' attribute.")
+            logger.raise_error("Given huggingface model does not have 'base_model_prefix' attribute.", RuntimeError)

         for param in getattr(model, model.base_model_prefix).parameters():
             param.requires_grad = False
@@ -141,11 +147,10 @@ def load_model_and_tokenizer(
     # If there is a mismatch between tokenizer vocab size and embedding matrix,
     # throw a warning and then expand the embedding matrix
     if len(tokenizer) > model.get_input_embeddings().weight.shape[0]:
-        print("WARNING: Resizing embedding matrix to match tokenizer vocab size.")
+        logger.log_rank_zero("Resizing the embedding matrix to match the tokenizer vocab size.", logging.WARNING)
         model.resize_token_embeddings(len(tokenizer))

-    # FIXME (Meet): Cover below line inside the logger once it is implemented.
-    print_model_size(model, train_config)
+    print_model_size(model)

     # Note: Need to call this before calling PeftModel.from_pretrained or get_peft_model.
     # Because, both makes model.is_gradient_checkpointing = True which is used in peft library to
@@ -157,7 +162,9 @@ def load_model_and_tokenizer(
         if hasattr(model, "supports_gradient_checkpointing") and model.supports_gradient_checkpointing:
             model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"preserve_rng_state": False})
         else:
-            raise RuntimeError("Given model doesn't support gradient checkpointing. Please disable it and run it.")
+            logger.raise_error(
+                "Given model doesn't support gradient checkpointing. Please disable it and run it.", RuntimeError
+            )

     model = apply_peft(model, train_config, peft_config_file, **kwargs)

@@ -192,7 +199,7 @@ def apply_peft(
     else:
         peft_config = generate_peft_config(train_config, peft_config_file, **kwargs)
         model = get_peft_model(model, peft_config)
-        model.print_trainable_parameters()
+        print_trainable_parameters(model)

     return model

@@ -217,25 +224,26 @@ def setup_dataloaders(
             - Length of longest sequence in the dataset.

     Raises:
-        ValueError: If validation is enabled but the validation set is too small.
+        RuntimeError: If validation is enabled but the validation set is too small.

     Notes:
         - Applies a custom data collator if provided by get_custom_data_collator.
         - Configures DataLoader kwargs using get_dataloader_kwargs for train and val splits.
     """

     train_dataloader = get_dataloader(tokenizer, dataset_config, train_config, split="train")
-    print(f"--> Num of Training Set Batches loaded = {len(train_dataloader)}")
+    logger.log_rank_zero(f"Number of Training Set Batches loaded = {len(train_dataloader)}")

     eval_dataloader = None
     if train_config.run_validation:
         eval_dataloader = get_dataloader(tokenizer, dataset_config, train_config, split="val")
         if len(eval_dataloader) == 0:
-            raise ValueError(
-                f"The eval set size is too small for dataloader to load even one batch. Please increase the size of eval set. ({len(eval_dataloader)=})"
+            logger.raise_error(
+                f"The eval set size is too small for dataloader to load even one batch. Please increase the size of eval set. ({len(eval_dataloader)=})",
+                ValueError,
             )
         else:
-            print(f"--> Num of Validation Set Batches loaded = {len(eval_dataloader)}")
+            logger.log_rank_zero(f"Number of Validation Set Batches loaded = {len(eval_dataloader)}")

     longest_seq_length, _ = get_longest_seq_length(
         torch.utils.data.ConcatDataset([train_dataloader.dataset, eval_dataloader.dataset])
@@ -274,13 +282,15 @@ def main(peft_config_file: str = None, **kwargs) -> None:
     dataset_config = generate_dataset_config(train_config.dataset)
     update_config(dataset_config, **kwargs)

+    logger.prepare_for_logs(train_config.output_dir, train_config.dump_logs, train_config.log_level)
+
     setup_distributed_training(train_config)
     setup_seeds(train_config.seed)
     model, tokenizer = load_model_and_tokenizer(train_config, dataset_config, peft_config_file, **kwargs)

     # Create DataLoaders for the training and validation dataset
     train_dataloader, eval_dataloader, longest_seq_length = setup_dataloaders(train_config, dataset_config, tokenizer)
-    print(
+    logger.log_rank_zero(
         f"The longest sequence length in the train data is {longest_seq_length}, "
         f"passed context length is {train_config.context_length} and overall model's context length is "
         f"{model.config.max_position_embeddings}"