Skip to content
This repository was archived by the owner on Mar 21, 2024. It is now read-only.

Commit b97ca87

Browse files
Merge branch 'main' into vsalva/precommit-hooks
2 parents d1d2a7f + 521c004 commit b97ca87

File tree

9 files changed

+111
-60
lines changed

9 files changed

+111
-60
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ jobs that run in AzureML.
2626

2727
### Changed
2828
- ([#531](https://github.com/microsoft/InnerEye-DeepLearning/pull/531)) Updated PL to 1.3.8, torchmetrics and pl-bolts and changed relevant metrics and SSL code API.
29+
- ([#555](https://github.com/microsoft/InnerEye-DeepLearning/pull/555)) Make the SSLContainer compatible with new datasets
2930
- ([#533](https://github.com/microsoft/InnerEye-DeepLearning/pull/533)) Better defaults for inference on ensemble children.
3031
- ([#536](https://github.com/microsoft/InnerEye-DeepLearning/pull/536)) Inference will not run on the validation set by default, this can be turned on
3132
via the `--inference_on_val_set` flag.

InnerEye/ML/SSL/datamodules_and_datasets/datamodules.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ def __init__(self,
2929
num_workers: int = 6,
3030
batch_size: int = 32,
3131
seed: int = 42,
32+
drop_last: bool = True,
3233
*args: Any, **kwargs: Any) -> None:
3334
"""
3435
Wrapper around VisionDatamodule to load torchvision dataset into a pytorch-lightning module.
@@ -42,16 +43,17 @@ def __init__(self,
4243
:param val_transforms: transforms to use at validation time
4344
:param data_dir: data directory where to find the data
4445
:param val_split: proportion of training dataset to use for validation
45-
:param num_workers: number of processes for dataloaders.
46-
:param batch_size: batch size for training & validation.
46+
:param num_workers: number of processes for dataloaders
47+
:param batch_size: batch size for training & validation
4748
:param seed: random seed for dataset splitting
49+
:param drop_last: bool, if true it drops the last incomplete batch
4850
"""
4951
data_dir = data_dir if data_dir is not None else os.getcwd()
5052
super().__init__(data_dir=data_dir,
5153
val_split=val_split,
5254
num_workers=num_workers,
5355
batch_size=batch_size,
54-
drop_last=True,
56+
drop_last=drop_last,
5557
train_transforms=train_transforms,
5658
val_transforms=val_transforms,
5759
seed=seed,

InnerEye/ML/SSL/datamodules_and_datasets/transforms_utils.py

Lines changed: 14 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -10,16 +10,17 @@
1010
from pl_bolts.models.self_supervised.simclr import SimCLRTrainDataTransform
1111
from yacs.config import CfgNode
1212

13-
from InnerEye.ML.augmentations.transform_pipeline import create_cxr_transforms_from_config
13+
from InnerEye.ML.augmentations.transform_pipeline import create_transforms_from_config
1414

1515

16-
def get_cxr_ssl_transforms(config: CfgNode,
17-
return_two_views_per_sample: bool,
18-
use_training_augmentations_for_validation: bool = False) -> Tuple[Any, Any]:
16+
def get_ssl_transforms_from_config(config: CfgNode,
17+
return_two_views_per_sample: bool,
18+
use_training_augmentations_for_validation: bool = False,
19+
expand_channels: bool = True) -> Tuple[Any, Any]:
1920
"""
2021
Returns training and validation transforms for CXR.
2122
Transformations are constructed in the following way:
22-
1. Construct the pipeline of augmentations in create_chest_xray_transform (e.g. resize, flip, affine) as defined
23+
1. Construct the pipeline of augmentations in create_transform_from_config (e.g. resize, flip, affine) as defined
2324
by the config.
2425
2. If we just want to construct the transformation pipeline for a classification model or for the linear evaluator
2526
of the SSL module, return this pipeline.
@@ -29,14 +30,18 @@ def get_cxr_ssl_transforms(config: CfgNode,
2930
3031
:param config: configuration defining which augmentations to apply as well as their intensities.
3132
:param return_two_views_per_sample: if True the resulting transforms will return two versions of each sample they
32-
are called on. If False, simply return one transformed version of the sample.
33+
are called on. If False, simply return one transformed version of the sample centered and cropped.
3334
:param use_training_augmentations_for_validation: If True, use augmentation at validation time too.
3435
This is required for SSL validation loss to be meaningful. If False, only apply basic processing step
3536
(no augmentations)
37+
:param expand_channels: if True the expand channel transformation from InnerEye.ML.augmentations.image_transforms
38+
will be added to the transformation passed through the config. This is needed for single channel images as CXR.
3639
"""
37-
train_transforms = create_cxr_transforms_from_config(config, apply_augmentations=True)
38-
val_transforms = create_cxr_transforms_from_config(config,
39-
apply_augmentations=use_training_augmentations_for_validation)
40+
train_transforms = create_transforms_from_config(config, apply_augmentations=True,
41+
expand_channels=expand_channels)
42+
val_transforms = create_transforms_from_config(config,
43+
apply_augmentations=use_training_augmentations_for_validation,
44+
expand_channels=expand_channels)
4045
if return_two_views_per_sample:
4146
train_transforms = DualViewTransformWrapper(train_transforms) # type: ignore
4247
val_transforms = DualViewTransformWrapper(val_transforms) # type: ignore

InnerEye/ML/SSL/lightning_containers/ssl_container.py

Lines changed: 26 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
from InnerEye.ML.SSL.datamodules_and_datasets.datamodules import CombinedDataModule, InnerEyeVisionDataModule
1818
from InnerEye.ML.SSL.datamodules_and_datasets.transforms_utils import InnerEyeCIFARLinearHeadTransform, \
1919
InnerEyeCIFARTrainTransform, \
20-
get_cxr_ssl_transforms
20+
get_ssl_transforms_from_config
2121
from InnerEye.ML.SSL.encoders import get_encoder_output_dim
2222
from InnerEye.ML.SSL.lightning_modules.byol.byol_module import BYOLInnerEye
2323
from InnerEye.ML.SSL.lightning_modules.simclr_module import SimCLRInnerEye
@@ -96,6 +96,7 @@ class SSLContainer(LightningContainer):
9696
learning_rate_linear_head_during_ssl_training = param.Number(default=1e-4,
9797
doc="Learning rate for linear head training during "
9898
"SSL training.")
99+
drop_last = param.Boolean(default=True, doc="If True drops the last incomplete batch")
99100

100101
def setup(self) -> None:
101102
from InnerEye.ML.SSL.lightning_containers.ssl_image_classifier import SSLClassifierContainer
@@ -166,8 +167,8 @@ def create_model(self) -> LightningModule:
166167
f"Found {self.ssl_training_type.value}")
167168
model.hparams.update({'ssl_type': self.ssl_training_type.value,
168169
"num_classes": self.data_module.num_classes})
169-
self.encoder_output_dim = get_encoder_output_dim(model, self.data_module)
170170

171+
self.encoder_output_dim = get_encoder_output_dim(model, self.data_module)
171172
return model
172173

173174
def get_data_module(self) -> InnerEyeDataModuleTypes:
@@ -186,7 +187,7 @@ def _create_ssl_data_modules(self, is_ssl_encoder_module: bool) -> InnerEyeVisio
186187
"""
187188
Returns torch lightning data module for encoder or linear head
188189
189-
:param is_ssl_encoder_module: whether to return the data module for SSL training or for linear heard. If true,
190+
:param is_ssl_encoder_module: whether to return the data module for SSL training or for linear head. If true,
190191
:return transforms with two views per sample (batch like (img_v1, img_v2, label)). If False, return only one
191192
view per sample but also return the index of the sample in the dataset (to make sure we don't use twice the same
192193
batch in one training epoch (batch like (index, img_v1, label), as classifier dataloader expected to be shorter
@@ -209,7 +210,8 @@ def _create_ssl_data_modules(self, is_ssl_encoder_module: bool) -> InnerEyeVisio
209210
data_dir=str(datamodule_args.dataset_path),
210211
batch_size=batch_size_per_gpu,
211212
num_workers=self.num_workers,
212-
seed=self.random_seed)
213+
seed=self.random_seed,
214+
drop_last=self.drop_last)
213215
dm.prepare_data()
214216
dm.setup()
215217
return dm
@@ -223,25 +225,39 @@ def _get_transforms(self, augmentation_config: Optional[CfgNode],
223225
examples.
224226
:param dataset_name: name of the dataset, value has to be in SSLDatasetName, determines which transformation
225227
pipeline to return.
226-
:param is_ssl_encoder_module: if True the transformation pipeline will yield two version of the image it is
227-
applied on. If False, return only one transformation.
228+
:param is_ssl_encoder_module: if True the transformation pipeline will yield two versions of the image it is
229+
applied on and it applies the training transformations also at validation time. Note that if your transformation
230+
does not contain any randomness, the pipeline will return two identical copies. If False, it will return only one
231+
transformation.
228232
:return: training transformation pipeline and validation transformation pipeline.
229233
"""
230234
if dataset_name in [SSLDatasetName.RSNAKaggleCXR.value,
231235
SSLDatasetName.NIHCXR.value,
232236
SSLDatasetName.CheXpert.value,
233237
SSLDatasetName.Covid.value]:
234238
assert augmentation_config is not None
235-
train_transforms, val_transforms = get_cxr_ssl_transforms(augmentation_config,
236-
return_two_views_per_sample=is_ssl_encoder_module,
237-
use_training_augmentations_for_validation=is_ssl_encoder_module)
239+
train_transforms, val_transforms = get_ssl_transforms_from_config(
240+
augmentation_config,
241+
return_two_views_per_sample=is_ssl_encoder_module,
242+
use_training_augmentations_for_validation=is_ssl_encoder_module
243+
)
238244
elif dataset_name in [SSLDatasetName.CIFAR10.value, SSLDatasetName.CIFAR100.value]:
239245
train_transforms = \
240246
InnerEyeCIFARTrainTransform(32) if is_ssl_encoder_module else InnerEyeCIFARLinearHeadTransform(32)
241247
val_transforms = \
242248
InnerEyeCIFARTrainTransform(32) if is_ssl_encoder_module else InnerEyeCIFARLinearHeadTransform(32)
249+
elif augmentation_config:
250+
train_transforms, val_transforms = get_ssl_transforms_from_config(
251+
augmentation_config,
252+
return_two_views_per_sample=is_ssl_encoder_module,
253+
use_training_augmentations_for_validation=is_ssl_encoder_module,
254+
expand_channels=False,
255+
)
256+
logging.warning(f"Dataset {dataset_name} unknown. The config will be consumed by "
257+
f"get_ssl_transforms() to create the augmentation pipeline, make sure "
258+
f"the transformations in your configs are compatible. ")
243259
else:
244-
raise ValueError(f"Dataset {dataset_name} unknown.")
260+
raise ValueError(f"Dataset {dataset_name} unknown and no config has been passed.")
245261

246262
return train_transforms, val_transforms
247263

InnerEye/ML/augmentations/transform_pipeline.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -86,16 +86,22 @@ def __call__(self, data: ImageData) -> torch.Tensor:
8686
return self.transform_image(data)
8787

8888

89-
def create_cxr_transforms_from_config(config: CfgNode,
90-
apply_augmentations: bool) -> ImageTransformationPipeline:
89+
def create_transforms_from_config(config: CfgNode,
90+
apply_augmentations: bool,
91+
expand_channels: bool = True) -> ImageTransformationPipeline:
9192
"""
92-
Defines the image transformations pipeline used in Chest-Xray datasets. Can be used for other types of
93-
images data, type of augmentations to use and strength are expected to be defined in the config.
93+
Defines the image transformations pipeline from a config file. It has been designed for Chest X-Ray
94+
images but it can be used for other types of images data, type of augmentations to use and strength are
95+
expected to be defined in the config. The channel expansion is needed for gray images.
9496
:param config: config yaml file fixing strength and type of augmentation to apply
9597
:param apply_augmentations: if True return transformation pipeline with augmentations. Else,
9698
disable augmentations i.e. only resize and center crop the image.
99+
:param expand_channels: if True the expand channel transformation from InnerEye.ML.augmentations.image_transforms
100+
will be added to the transformation passed through the config. This is needed for single channel images as CXR.
97101
"""
98-
transforms: List[Any] = [ExpandChannels()]
102+
transforms: List[Any] = []
103+
if expand_channels:
104+
transforms.append(ExpandChannels())
99105
if apply_augmentations:
100106
if config.augmentation.use_random_affine:
101107
transforms.append(RandomAffine(

InnerEye/ML/configs/classification/CovidModel.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,8 @@
2323
from InnerEye.ML.SSL.lightning_containers.ssl_container import EncoderName
2424
from InnerEye.ML.SSL.lightning_modules.ssl_classifier_module import SSLClassifier
2525
from InnerEye.ML.SSL.utils import create_ssl_image_classifier, load_yaml_augmentation_config
26-
from InnerEye.ML.augmentations.transform_pipeline import create_cxr_transforms_from_config
26+
from InnerEye.ML.augmentations.transform_pipeline import create_transforms_from_config
27+
2728
from InnerEye.ML.common import ModelExecutionMode
2829

2930
from InnerEye.ML.configs.ssl.CXR_SSL_configs import path_linear_head_augmentation_cxr
@@ -137,9 +138,9 @@ def get_model_train_test_dataset_splits(self, dataset_df: pd.DataFrame) -> Datas
137138
def get_image_transform(self) -> ModelTransformsPerExecutionMode:
138139
config = load_yaml_augmentation_config(path_linear_head_augmentation_cxr)
139140
train_transforms = Compose(
140-
[DicomPreparation(), create_cxr_transforms_from_config(config, apply_augmentations=True)])
141+
[DicomPreparation(), create_transforms_from_config(config, apply_augmentations=True)])
141142
val_transforms = Compose(
142-
[DicomPreparation(), create_cxr_transforms_from_config(config, apply_augmentations=False)])
143+
[DicomPreparation(), create_transforms_from_config(config, apply_augmentations=False)])
143144

144145
return ModelTransformsPerExecutionMode(train=train_transforms,
145146
val=val_transforms,

Tests/ML/augmentations/test_transform_pipeline.py

Lines changed: 27 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import PIL
88
import pytest
99
import torch
10+
1011
from torchvision.transforms import (
1112
CenterCrop,
1213
ColorJitter,
@@ -46,7 +47,6 @@
4647
test_4d_scan_as_tensor = torch.ones([5, 4, *image_size]) * 255.0
4748
test_4d_scan_as_tensor[..., 10:15, 10:20] = 1
4849

49-
5050
@pytest.mark.parametrize("use_different_transformation_per_channel", [True, False])
5151
def test_torchvision_on_various_input(
5252
use_different_transformation_per_channel: bool,
@@ -136,17 +136,19 @@ def test_custom_tf_on_various_input(
136136
)
137137

138138

139-
def test_create_transform_pipeline_from_config() -> None:
139+
@pytest.mark.parametrize("expand_channels", [True, False])
140+
def test_create_transform_pipeline_from_config(expand_channels: bool) -> None:
140141
"""
141142
Tests that the pipeline returned by create_transform_pipeline_from_config returns the expected transformation.
142143
"""
144+
143145
transformation_pipeline = create_cxr_transforms_from_config(
144-
cxr_augmentation_config, apply_augmentations=True
146+
cxr_augmentation_config, apply_augmentations=True,
147+
expand_channels=expand_channels
145148
)
146149
fake_cxr_as_array = np.ones([256, 256]) * 255.0
147150
fake_cxr_as_array[100:150, 100:200] = 1
148-
fake_cxr_image = PIL.Image.fromarray(fake_cxr_as_array).convert("L")
149-
151+
150152
all_transforms = [
151153
ExpandChannels(),
152154
RandomAffine(degrees=180, translate=(0, 0), shear=40),
@@ -160,23 +162,28 @@ def test_create_transform_pipeline_from_config() -> None:
160162
AddGaussianNoise(std=0.05, p_apply=0.5),
161163
]
162164

165+
if expand_channels:
166+
all_transforms.insert(0, ExpandChannels())
167+
# expand channels is used for single-channel input images
168+
fake_image = PIL.Image.fromarray(fake_cxr_as_array).convert("L")
169+
# In the pipeline the image is converted to tensor before applying the transformations. Do the same here.
170+
image = ToTensor()(fake_image).reshape([1, 1, 256, 256])
171+
else:
172+
fake_3d_array = np.dstack([fake_cxr_as_array, fake_cxr_as_array, fake_cxr_as_array])
173+
fake_image = PIL.Image.fromarray(fake_3d_array.astype(np.uint8)).convert("RGB")
174+
# In the pipeline the image is converted to tensor before applying the transformations. Do the same here.
175+
image = ToTensor()(fake_image).reshape([1, 3, 256, 256])
176+
163177
np.random.seed(3)
164178
torch.manual_seed(3)
165179
random.seed(3)
166-
167-
transformed_image = transformation_pipeline(fake_cxr_image)
180+
transformed_image = transformation_pipeline(fake_image)
168181
assert isinstance(transformed_image, torch.Tensor)
169-
# Expected pipeline
170-
image = np.ones([256, 256]) * 255.0
171-
image[100:150, 100:200] = 1
172-
image = PIL.Image.fromarray(image).convert("L")
173-
# In the pipeline the image is converted to tensor before applying the transformations. Do the same here.
174-
image = ToTensor()(image).reshape([1, 1, 256, 256])
175182

183+
# Expected pipeline
176184
np.random.seed(3)
177185
torch.manual_seed(3)
178186
random.seed(3)
179-
180187
expected_transformed = image
181188
for t in all_transforms:
182189
expected_transformed = t(expected_transformed)
@@ -187,11 +194,15 @@ def test_create_transform_pipeline_from_config() -> None:
187194

188195
# Test the evaluation pipeline
189196
transformation_pipeline = create_cxr_transforms_from_config(
190-
cxr_augmentation_config, apply_augmentations=False
197+
cxr_augmentation_config, apply_augmentations=False,
198+
expand_channels=expand_channels,
191199
)
192200
transformed_image = transformation_pipeline(image)
193201
assert isinstance(transformed_image, torch.Tensor)
194-
all_transforms = [ExpandChannels(), Resize(size=256), CenterCrop(size=224)]
202+
all_transforms = [Resize(size=256), CenterCrop(size=224)]
203+
if expand_channels:
204+
all_transforms.insert(0, ExpandChannels())
205+
195206
expected_transformed = image
196207
for t in all_transforms:
197208
expected_transformed = t(expected_transformed)

0 commit comments

Comments
 (0)