File tree Expand file tree Collapse file tree 2 files changed +3
-3
lines changed Expand file tree Collapse file tree 2 files changed +3
-3
lines changed Original file line number Diff line number Diff line change @@ -744,7 +744,7 @@ def main():
744744 distributed = args .distributed ,
745745 collate_fn = collate_fn ,
746746 pin_memory = args .pin_mem ,
747- img_dtype = model_dtype ,
747+ img_dtype = model_dtype or torch . float32 ,
748748 device = device ,
749749 use_prefetcher = args .prefetcher ,
750750 use_multi_epochs_loader = args .use_multi_epochs_loader ,
@@ -769,7 +769,7 @@ def main():
769769 distributed = args .distributed ,
770770 crop_pct = data_config ['crop_pct' ],
771771 pin_memory = args .pin_mem ,
772- img_dtype = model_dtype ,
772+ img_dtype = model_dtype or torch . float32 ,
773773 device = device ,
774774 use_prefetcher = args .prefetcher ,
775775 )
Original file line number Diff line number Diff line change @@ -307,7 +307,7 @@ def validate(args):
307307 crop_border_pixels = args .crop_border_pixels ,
308308 pin_memory = args .pin_mem ,
309309 device = device ,
310- img_dtype = model_dtype ,
310+ img_dtype = model_dtype or torch . float32 ,
311311 tf_preprocessing = args .tf_preprocessing ,
312312 )
313313
You can’t perform that action at this time.
0 commit comments