1717from  InnerEye .ML .SSL .datamodules_and_datasets .datamodules  import  CombinedDataModule , InnerEyeVisionDataModule 
1818from  InnerEye .ML .SSL .datamodules_and_datasets .transforms_utils  import  InnerEyeCIFARLinearHeadTransform , \
1919    InnerEyeCIFARTrainTransform , \
20-     get_cxr_ssl_transforms 
20+     get_ssl_transforms_from_config 
2121from  InnerEye .ML .SSL .encoders  import  get_encoder_output_dim 
2222from  InnerEye .ML .SSL .lightning_modules .byol .byol_module  import  BYOLInnerEye 
2323from  InnerEye .ML .SSL .lightning_modules .simclr_module  import  SimCLRInnerEye 
@@ -96,6 +96,7 @@ class SSLContainer(LightningContainer):
9696    learning_rate_linear_head_during_ssl_training  =  param .Number (default = 1e-4 ,
9797                                                                 doc = "Learning rate for linear head training during " 
9898                                                                     "SSL training." )
99+     drop_last  =  param .Boolean (default = True , doc = "If True drops the last incomplete batch" )
99100
100101    def  setup (self ) ->  None :
101102        from  InnerEye .ML .SSL .lightning_containers .ssl_image_classifier  import  SSLClassifierContainer 
@@ -166,8 +167,8 @@ def create_model(self) -> LightningModule:
166167                f"Found { self .ssl_training_type .value }  )
167168        model .hparams .update ({'ssl_type' : self .ssl_training_type .value ,
168169                              "num_classes" : self .data_module .num_classes })
169-         self .encoder_output_dim  =  get_encoder_output_dim (model , self .data_module )
170170
171+         self .encoder_output_dim  =  get_encoder_output_dim (model , self .data_module )
171172        return  model 
172173
173174    def  get_data_module (self ) ->  InnerEyeDataModuleTypes :
@@ -186,7 +187,7 @@ def _create_ssl_data_modules(self, is_ssl_encoder_module: bool) -> InnerEyeVisio
186187        """ 
187188        Returns torch lightning data module for encoder or linear head 
188189
189-         :param is_ssl_encoder_module: whether to return the data module for SSL training or for linear heard . If true, 
190+         :param is_ssl_encoder_module: whether to return the data module for SSL training or for the linear head. If True,
190191        return transforms with two views per sample (batch like (img_v1, img_v2, label)). If False, return only one
191192        view per sample but also return the index of the sample in the dataset (to make sure we don't use twice the same 
192193        batch in one training epoch (batch like (index, img_v1, label), as classifier dataloader expected to be shorter 
@@ -209,7 +210,8 @@ def _create_ssl_data_modules(self, is_ssl_encoder_module: bool) -> InnerEyeVisio
209210                                      data_dir = str (datamodule_args .dataset_path ),
210211                                      batch_size = batch_size_per_gpu ,
211212                                      num_workers = self .num_workers ,
212-                                       seed = self .random_seed )
213+                                       seed = self .random_seed ,
214+                                       drop_last = self .drop_last )
213215        dm .prepare_data ()
214216        dm .setup ()
215217        return  dm 
@@ -223,25 +225,39 @@ def _get_transforms(self, augmentation_config: Optional[CfgNode],
223225        examples. 
224226        :param dataset_name: name of the dataset, value has to be in SSLDatasetName, determines which transformation 
225227        pipeline to return. 
226-         :param is_ssl_encoder_module: if True the transformation pipeline will yield two version of the image it is 
227-         applied on. If False, return only one transformation. 
228+         :param is_ssl_encoder_module: if True, the transformation pipeline yields two versions of the image it is
229+         applied to, and the training transformations are also applied at validation time. Note that if your
230+         transformations do not contain any randomness, the pipeline will return two identical copies. If False, the
231+         pipeline returns only one transformed version of the image.
228232        :return: training transformation pipeline and validation transformation pipeline. 
229233        """ 
230234        if  dataset_name  in  [SSLDatasetName .RSNAKaggleCXR .value ,
231235                            SSLDatasetName .NIHCXR .value ,
232236                            SSLDatasetName .CheXpert .value ,
233237                            SSLDatasetName .Covid .value ]:
234238            assert  augmentation_config  is  not None 
235-             train_transforms , val_transforms  =  get_cxr_ssl_transforms (augmentation_config ,
236-                                                                       return_two_views_per_sample = is_ssl_encoder_module ,
237-                                                                       use_training_augmentations_for_validation = is_ssl_encoder_module )
239+             train_transforms , val_transforms  =  get_ssl_transforms_from_config (
240+                 augmentation_config ,
241+                 return_two_views_per_sample = is_ssl_encoder_module ,
242+                 use_training_augmentations_for_validation = is_ssl_encoder_module 
243+             )
238244        elif  dataset_name  in  [SSLDatasetName .CIFAR10 .value , SSLDatasetName .CIFAR100 .value ]:
239245            train_transforms  =  \
240246                InnerEyeCIFARTrainTransform (32 ) if  is_ssl_encoder_module  else  InnerEyeCIFARLinearHeadTransform (32 )
241247            val_transforms  =  \
242248                InnerEyeCIFARTrainTransform (32 ) if  is_ssl_encoder_module  else  InnerEyeCIFARLinearHeadTransform (32 )
249+         elif  augmentation_config :
250+             train_transforms , val_transforms  =  get_ssl_transforms_from_config (
251+                 augmentation_config ,
252+                 return_two_views_per_sample = is_ssl_encoder_module ,
253+                 use_training_augmentations_for_validation = is_ssl_encoder_module ,
254+                 expand_channels = False ,
255+             )
256+             logging .warning (f"Dataset { dataset_name }  
257+                             f"get_ssl_transforms() to create the augmentation pipeline, make sure " 
258+                             f"the transformations in your configs are compatible. " )
243259        else :
244-             raise  ValueError (f"Dataset { dataset_name }  )
260+             raise  ValueError (f"Dataset { dataset_name }  and no config has been passed ." )
245261
246262        return  train_transforms , val_transforms 
247263
0 commit comments