File tree Expand file tree Collapse file tree 2 files changed +5
-3
lines changed
neural_compressor/experimental Expand file tree Collapse file tree 2 files changed +5
-3
lines changed Original file line number Diff line number Diff line change @@ -226,7 +226,9 @@ def pre_process(self):
226226 framework_specific_info = {'device' : self .cfg .device ,
227227 'random_seed' : self .cfg .tuning .random_seed ,
228228 'workspace_path' : self .cfg .tuning .workspace .path ,
229- 'q_dataloader' : None }
229+ 'q_dataloader' : None ,
230+ 'format' : 'default' ,
231+ 'backend' : 'default' }
230232
231233 if self .framework == 'tensorflow' :
232234 framework_specific_info .update (
Original file line number Diff line number Diff line change 99from neural_compressor .data import DATASETS
1010from neural_compressor .config import DistillationConfig , KnowledgeDistillationLossConfig
1111from neural_compressor .experimental .data .dataloaders .pytorch_dataloader import PyTorchDataLoader
12-
12+ from neural_compressor . adaptor . tf_utils . util import version1_lt_version2
1313
1414def build_fake_yaml ():
1515 fake_yaml = """
@@ -252,7 +252,7 @@ def test_distillation_external_new_API(self):
252252 stat = torch .load ('./saved/best_model.pt' )
253253 opt_model = self .student_model .load_state_dict (stat )
254254
255- @unittest .skipIf (tf .version .VERSION < '2.3.0' , " keras requires higher version than tf-2.3.0" )
255+ @unittest .skipIf (version1_lt_version2 ( tf .version .VERSION , '2.3.0' ) , " keras requires higher version than tf-2.3.0" )
256256 def test_tf_distillation (self ):
257257 from neural_compressor .experimental import Distillation
258258 from neural_compressor .conf .config import DistillationConf
You can’t perform that action at this time.
0 commit comments