Skip to content

Commit 846c45f

Browse files
committed
fix distillation ut issue
Signed-off-by: Lv, Liang1 <[email protected]>
1 parent 34b8ece commit 846c45f

File tree

2 files changed

+5
-3
lines changed

2 files changed

+5
-3
lines changed

neural_compressor/experimental/distillation.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -226,7 +226,9 @@ def pre_process(self):
226226
framework_specific_info = {'device': self.cfg.device,
227227
'random_seed': self.cfg.tuning.random_seed,
228228
'workspace_path': self.cfg.tuning.workspace.path,
229-
'q_dataloader': None}
229+
'q_dataloader': None,
230+
'format': 'default',
231+
'backend': 'default'}
230232

231233
if self.framework == 'tensorflow':
232234
framework_specific_info.update(

test/distillation/test_distillation.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from neural_compressor.data import DATASETS
1010
from neural_compressor.config import DistillationConfig, KnowledgeDistillationLossConfig
1111
from neural_compressor.experimental.data.dataloaders.pytorch_dataloader import PyTorchDataLoader
12-
12+
from neural_compressor.adaptor.tf_utils.util import version1_lt_version2
1313

1414
def build_fake_yaml():
1515
fake_yaml = """
@@ -252,7 +252,7 @@ def test_distillation_external_new_API(self):
252252
stat = torch.load('./saved/best_model.pt')
253253
opt_model = self.student_model.load_state_dict(stat)
254254

255-
@unittest.skipIf(tf.version.VERSION < '2.3.0', " keras requires higher version than tf-2.3.0")
255+
@unittest.skipIf(version1_lt_version2(tf.version.VERSION, '2.3.0'), " keras requires higher version than tf-2.3.0")
256256
def test_tf_distillation(self):
257257
from neural_compressor.experimental import Distillation
258258
from neural_compressor.conf.config import DistillationConf

0 commit comments

Comments (0)