                    help='lower precision AMP dtype (default: float16)')
 group.add_argument('--amp-impl', default='native', type=str,
                    help='AMP impl to use, "native" or "apex" (default: native)')
+group.add_argument('--model-dtype', default=None, type=str,
+                   help='Model dtype override (non-AMP) (default: float32)')
 group.add_argument('--no-ddp-bb', action='store_true', default=False,
                    help='Force broadcast buffers for native DDP to off.')
 group.add_argument('--synchronize-step', action='store_true', default=False,
@@ -436,10 +438,18 @@ def main():
         _logger.info(f'Training with a single process on 1 device ({args.device}).')
     assert args.rank >= 0
 
+    model_dtype = None
+    if args.model_dtype:
+        assert args.model_dtype in ('float32', 'float16', 'bfloat16')
+        model_dtype = getattr(torch, args.model_dtype)
+        if model_dtype == torch.float16:
+            _logger.warning('float16 is not recommended for training, for half precision bfloat16 is recommended.')
+
     # resolve AMP arguments based on PyTorch / Apex availability
     use_amp = None
     amp_dtype = torch.float16
     if args.amp:
+        assert model_dtype is None or model_dtype == torch.float32, 'float32 model dtype must be used with AMP'
         if args.amp_impl == 'apex':
             assert has_apex, 'AMP impl specified as APEX but APEX is not installed.'
             use_amp = 'apex'
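
The hunk above turns the new `--model-dtype` string into a `torch.dtype` and rejects reduced-precision weights when AMP is enabled, since AMP assumes float32 master weights and autocasts compute on the fly. Below is a minimal standalone sketch of that resolution logic; the helper name and the print-based warning are illustrative, not part of the commit.

```python
import torch

def resolve_model_dtype(model_dtype_arg, use_amp):
    """Illustrative helper mirroring the dtype checks in the hunk above."""
    if not model_dtype_arg:
        return None  # keep the default float32 weights
    assert model_dtype_arg in ('float32', 'float16', 'bfloat16')
    dtype = getattr(torch, model_dtype_arg)
    if dtype == torch.float16:
        # bfloat16 keeps the float32 exponent range and is the safer low-precision choice
        print('warning: float16 is not recommended for training, prefer bfloat16')
    if use_amp:
        # AMP expects float32 weights, so a reduced-precision override is rejected
        assert dtype == torch.float32, 'float32 model dtype must be used with AMP'
    return dtype

print(resolve_model_dtype('bfloat16', use_amp=False))  # torch.bfloat16
```
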
@@ -519,7 +529,7 @@ def main():
         model = convert_splitbn_model(model, max(num_aug_splits, 2))
 
     # move model to GPU, enable channels last layout if set
-    model.to(device=device)
+    model.to(device=device, dtype=model_dtype)  # FIXME move model device & dtype into create_model
     if args.channels_last:
         model.to(memory_format=torch.channels_last)
 
@@ -589,7 +599,7 @@ def main():
             _logger.info('Using native Torch AMP. Training in mixed precision.')
     else:
         if utils.is_primary(args):
-            _logger.info('AMP not enabled. Training in float32.')
+            _logger.info(f'AMP not enabled. Training in {model_dtype or torch.float32}.')
 
     # optionally resume from a checkpoint
     resume_epoch = None
@@ -734,6 +744,7 @@ def main():
         distributed=args.distributed,
         collate_fn=collate_fn,
         pin_memory=args.pin_mem,
+        img_dtype=model_dtype,
         device=device,
         use_prefetcher=args.prefetcher,
         use_multi_epochs_loader=args.use_multi_epochs_loader,
@@ -758,6 +769,7 @@ def main():
             distributed=args.distributed,
             crop_pct=data_config['crop_pct'],
             pin_memory=args.pin_mem,
+            img_dtype=model_dtype,
             device=device,
             use_prefetcher=args.prefetcher,
         )
@@ -822,21 +834,21 @@ def main():
         with open(os.path.join(output_dir, 'args.yaml'), 'w') as f:
             f.write(args_text)
 
-    if utils.is_primary(args) and args.log_wandb:
-        if has_wandb:
-            assert not args.wandb_resume_id or args.resume
-            wandb.init(
-                project=args.wandb_project,
-                name=args.experiment,
-                config=args,
-                tags=args.wandb_tags,
-                resume="must" if args.wandb_resume_id else None,
-                id=args.wandb_resume_id if args.wandb_resume_id else None,
-            )
-        else:
-            _logger.warning(
-                "You've requested to log metrics to wandb but package not found. "
-                "Metrics not being logged to wandb, try `pip install wandb`")
+        if args.log_wandb:
+            if has_wandb:
+                assert not args.wandb_resume_id or args.resume
+                wandb.init(
+                    project=args.wandb_project,
+                    name=exp_name,
+                    config=args,
+                    tags=args.wandb_tags,
+                    resume="must" if args.wandb_resume_id else None,
+                    id=args.wandb_resume_id if args.wandb_resume_id else None,
+                )
+            else:
+                _logger.warning(
+                    "You've requested to log metrics to wandb but package not found. "
+                    "Metrics not being logged to wandb, try `pip install wandb`")
 
     # setup learning rate schedule and starting epoch
     updates_per_epoch = (len(loader_train) + args.grad_accum_steps - 1) // args.grad_accum_steps
@@ -886,6 +898,7 @@ def main():
                 output_dir=output_dir,
                 amp_autocast=amp_autocast,
                 loss_scaler=loss_scaler,
+                model_dtype=model_dtype,
                 model_ema=model_ema,
                 mixup_fn=mixup_fn,
                 num_updates_total=num_epochs * updates_per_epoch,
@@ -904,6 +917,7 @@ def main():
                     args,
                     device=device,
                     amp_autocast=amp_autocast,
+                    model_dtype=model_dtype,
                 )
 
                 if model_ema is not None and not args.model_ema_force_cpu:
@@ -986,6 +1000,7 @@ def train_one_epoch(
         output_dir=None,
         amp_autocast=suppress,
         loss_scaler=None,
+        model_dtype=None,
         model_ema=None,
         mixup_fn=None,
         num_updates_total=None,
@@ -1022,7 +1037,7 @@ def train_one_epoch(
             accum_steps = last_accum_steps
 
         if not args.prefetcher:
-            input, target = input.to(device), target.to(device)
+            input, target = input.to(device=device, dtype=model_dtype), target.to(device=device)
             if mixup_fn is not None:
                 input, target = mixup_fn(input, target)
         if args.channels_last:
@@ -1149,6 +1164,7 @@ def validate(
         args,
         device=torch.device('cuda'),
         amp_autocast=suppress,
+        model_dtype=None,
         log_suffix=''
 ):
     batch_time_m = utils.AverageMeter()
@@ -1164,8 +1180,8 @@ def validate(
         for batch_idx, (input, target) in enumerate(loader):
             last_batch = batch_idx == last_idx
             if not args.prefetcher:
-                input = input.to(device)
-                target = target.to(device)
+                input = input.to(device=device, dtype=model_dtype)
+                target = target.to(device=device)
             if args.channels_last:
                 input = input.contiguous(memory_format=torch.channels_last)
 
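
Taken together, the remaining hunks keep the model weights, the loader output (`img_dtype`), and the manual casts in the non-prefetcher paths of `train_one_epoch` and `validate` on the same dtype. A small self-contained sketch of that pattern, assuming a recent PyTorch build with bfloat16 support (the toy `nn.Linear` stands in for the real model):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

model_dtype = torch.bfloat16
device = torch.device('cpu')  # any device with bfloat16 support works

# Model weights and inputs are kept in the same dtype.
model = nn.Linear(8, 2).to(device=device, dtype=model_dtype)

# Mirrors the non-prefetcher path: the input batch is cast to the model dtype,
# while integer class targets only move to the device.
input = torch.randn(4, 8).to(device=device, dtype=model_dtype)
target = torch.randint(0, 2, (4,)).to(device=device)

output = model(input)
loss = F.cross_entropy(output, target)
loss.backward()
print(loss.item())
```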