 from transformers import PreTrainedModel
 from transformers.training_args_tf import TFTrainingArguments
 from typing import Callable, Optional, List
-from .utils.utility_tf import TFDataloader, TMPPATH
+from .utils.utility_tf import TFDataloader, TMPPATH, get_filepath

 tf = LazyImport("tensorflow")
 logger = logging.getLogger(__name__)
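The hunks below route every temporary SavedModel export through the new `get_filepath` helper instead of the shared `TMPPATH` directory. Its implementation lives in `.utils.utility_tf` and is not part of this diff; a minimal sketch of the behavior the call sites assume (one export directory per `task_type`/`task_id`, falling back to `TMPPATH` when no task is set) could look like:

```python
# Hypothetical sketch only -- the real get_filepath is defined in
# .utils.utility_tf and is not shown in this diff.
import os

def get_filepath(tmp_path, task_type=None, task_id=None):
    # Without a task context, keep the original shared temp directory.
    if task_type is None or task_id is None:
        return tmp_path
    # Otherwise give each (task_type, task_id) pair its own subdirectory so
    # models exported for different tasks do not overwrite each other.
    path = os.path.join(tmp_path, str(task_type), str(task_id))
    os.makedirs(path, exist_ok=True)
    return path
```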
@@ -50,6 +50,8 @@ def __init__(
         compute_metrics: Optional[Callable] = None,
         criterion=None,
         optimizer=None,
+        task_type=None,
+        task_id=None,
     ):
         """
         Args:
@@ -78,11 +80,14 @@ def __init__(
         self.compute_metrics = compute_metrics
         self.args = args
         self.optimizer = optimizer
+        self.task_type = task_type
+        self.task_id = task_id
         self.criterion = criterion if criterion is not None else \
             self.model.loss if hasattr(self.model, "loss") else None
-        self.model.save_pretrained(TMPPATH, saved_model=True)
+        self.model.save_pretrained(get_filepath(TMPPATH, self.task_type, self.task_id), saved_model=True)
         _, self.input_names, self.output_names = saved_model_session(
-            os.path.join(TMPPATH, "saved_model/1"), input_tensor_names=[], output_tensor_names=[])
+            os.path.join(get_filepath(TMPPATH, self.task_type, self.task_id), "saved_model/1"), input_tensor_names=[],
+            output_tensor_names=[])
         self.eval_distributed = False

     @property
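With `task_type`/`task_id` threaded through the constructor, the model is exported to a task-specific SavedModel directory and `saved_model_session` reads the input/output tensor names back from that export. As an aside (not part of the diff), the same names can be inspected directly from the SavedModel's serving signature; the path below is illustrative:

```python
# Sketch: inspect an exported SavedModel's serving signature to see the
# input/output tensor names that saved_model_session() recovers above.
import os
import tensorflow as tf

export_dir = os.path.join("/tmp/exported_model", "saved_model", "1")  # hypothetical path
loaded = tf.saved_model.load(export_dir)
serving = loaded.signatures["serving_default"]
print("inputs:", list(serving.structured_input_signature[1].keys()))
print("outputs:", list(serving.structured_outputs.keys()))
```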
@@ -298,7 +303,8 @@ def init_quantizer(
         self.metrics = self.quant_config.metrics

         quantizer = Quantization(self.quant_config.inc_config)
-        quantizer.model = common.Model(os.path.join(TMPPATH, "saved_model/1"), modelType="saved_model")
+        quantizer.model = common.Model(
+            os.path.join(get_filepath(TMPPATH, self.task_type, self.task_id), "saved_model/1"), modelType="saved_model")

         self.quantizer = quantizer
         return quantizer
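`init_quantizer` now wraps the task-specific SavedModel in `common.Model` before handing it to the INC `Quantization` object. A rough sketch of the post-training-static flow this sets up, assuming Intel Neural Compressor's experimental API (`calib_dataloader`, `eval_func`, `fit()`); names outside the diff are placeholders:

```python
# Sketch of the post-training static quantization flow wired up above
# (assumes neural_compressor's experimental API; not taken verbatim from this diff).
quantizer = trainer.init_quantizer(quant_config=quant_config)   # hypothetical call site
quantizer.calib_dataloader = calib_dataloader                   # e.g. a TFDataloader over the train set
quantizer.eval_func = eval_func                                 # returns the metric used for accuracy checks
q_model = quantizer.fit()                                       # runs calibration + tuning
q_model.save("./quantized_model")
```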
@@ -325,8 +331,7 @@ def _inc_quantize(
                     batch_size=self.args.per_device_eval_batch_size)
                 else:  # pragma: no cover
                     assert False, "Please pass calibration dataset to TFNoTrainerOptimizer.calib_dataloader"
-        elif self.quant_config.approach == QuantizationMode.QUANTIZATIONAWARETRAINING.value:
-            # pragma: no cover
+        elif self.quant_config.approach == QuantizationMode.QUANTIZATIONAWARETRAINING.value:  # pragma: no cover
             assert False, \
                 "Unsupport quantization aware training for tensorflow framework"

@@ -369,7 +374,7 @@ def init_pruner(
369374 "please pass a instance of PruningConfig to trainer.prune!"
370375
371376 pruner = Pruning (self .pruning_config .inc_config )
372- pruner .model = os .path .join (TMPPATH ,"saved_model/1" )
377+ pruner .model = os .path .join (get_filepath ( TMPPATH , self . task_type , self . task_id ) ,"saved_model/1" )
373378 pruner .model .model_type = "saved_model"
374379
375380 self .pruner = pruner
@@ -416,7 +421,11 @@ def prune(

         opt_model = self.pruner.fit()

-        return self.model
+        opt_model.save(self.args.output_dir)
+        logger.info(
+            "pruned model have saved to {}".format(self.args.output_dir)
+        )
+        return opt_model.model

     def init_distiller(
         self,
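`prune()` now saves the pruned model to `args.output_dir` and returns the underlying model object instead of the original `self.model`. A minimal reload sketch, assuming `opt_model.save()` writes a TensorFlow SavedModel under that directory (the exact layout depends on the INC model type):

```python
# Sketch: reload the pruned model that prune() saved to args.output_dir.
# Assumes the output directory is itself a TensorFlow SavedModel.
import tensorflow as tf

output_dir = "./pruned_model"                     # hypothetical, mirrors args.output_dir
reloaded = tf.saved_model.load(output_dir)
infer = reloaded.signatures["serving_default"]    # run inference through the serving signature
```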
@@ -506,4 +515,4 @@ def on_train_batch_end(self, batch, logs=None):
                        callbacks=[PruningCb()])

         self.pruner.model._sess = None
-        input_model.save_pretrained(TMPPATH, saved_model=True)
+        input_model.save_pretrained(get_filepath(TMPPATH, self.task_type, self.task_id), saved_model=True)
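The pruning loop itself is driven from `tf.keras` training callbacks (the `PruningCb` passed to `fit()` above), and the refreshed model is re-exported through the same `get_filepath` path at the end. `PruningCb` is not shown in this hunk; a generic sketch of the callback pattern, with the pruner hook names as placeholders rather than the actual INC API:

```python
import tensorflow as tf

class PruningCbSketch(tf.keras.callbacks.Callback):
    """Sketch: forward Keras training events to a pruning component.

    The pruner hook names used here (on_epoch_begin, on_batch_begin, ...)
    are placeholders, not taken from this diff or verified against INC.
    """

    def __init__(self, pruner):
        super().__init__()
        self.pruner = pruner

    def on_epoch_begin(self, epoch, logs=None):
        self.pruner.on_epoch_begin(epoch)

    def on_train_batch_begin(self, batch, logs=None):
        self.pruner.on_batch_begin(batch)

    def on_train_batch_end(self, batch, logs=None):
        self.pruner.on_batch_end()

    def on_epoch_end(self, epoch, logs=None):
        self.pruner.on_epoch_end()
```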