diff --git a/CHANGELOG.md b/CHANGELOG.md
index 8cd5c5575..68e1707a7 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -21,6 +21,7 @@ loss.
   GPU utilization via Lightning's `GpuStatsMonitor`, switch `monitor_loading` to check batch loading times via
   `BatchTimeCallback`, and `pl_profiler` to turn on the Lightning profiler (`simple`, `advanced`, or `pytorch`)
 - ([#544](https://github.com/microsoft/InnerEye-DeepLearning/pull/544)) Add documentation for segmentation model evaluation.
+- ([#637](https://github.com/microsoft/InnerEye-DeepLearning/pull/637)) Add options to encode tiles in chunks and to load the pre-cached dataset on CPU or GPU in the histopathology pipeline.
 - ([#465](https://github.com/microsoft/InnerEye-DeepLearning/pull/465/)) Adding ability to run segmentation inference
   module on test data with partial ground truth files. (Also [522](https://github.com/microsoft/InnerEye-DeepLearning/pull/522).)
 - ([#502](https://github.com/microsoft/InnerEye-DeepLearning/pull/502)) More flags for fine control of when to run inference.
@@ -133,7 +134,7 @@ in inference-only runs when using lightning containers.
 
 ### Deprecated
 
-- ([#633](https://github.com/microsoft/InnerEye-DeepLearning/pull/633)) Model fields `recovery_checkpoint_save_interval` and `recovery_checkpoints_save_last_k` have been retired.
+- ([#633](https://github.com/microsoft/InnerEye-DeepLearning/pull/633)) Model fields `recovery_checkpoint_save_interval` and `recovery_checkpoints_save_last_k` have been retired. Recovery checkpoint handling is now controlled by `autosave_every_n_val_epochs`.
 
diff --git a/InnerEye/ML/Histopathology/datamodules/base_module.py b/InnerEye/ML/Histopathology/datamodules/base_module.py
index cb67c5506..b20d97f8c 100644
--- a/InnerEye/ML/Histopathology/datamodules/base_module.py
+++ b/InnerEye/ML/Histopathology/datamodules/base_module.py
@@ -3,7 +3,7 @@
 #  Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
 #  ------------------------------------------------------------------------------------------
 
-import pickle
+import torch
 from enum import Enum
 from pathlib import Path
 from typing import Any, Callable, Optional, Sequence, Tuple, Union
@@ -23,13 +23,17 @@ class CacheMode(Enum):
     MEMORY = 'memory'
     DISK = 'disk'
-
+class CacheLocation(Enum):
+    NONE = 'none'
+    CPU = 'cpu'
+    SAME = 'same'
 
 class TilesDataModule(LightningDataModule):
     """Base class to load the tiles of a dataset as train, val, test sets"""
 
     def __init__(self, root_path: Path, max_bag_size: int = 0, batch_size: int = 1,
                  seed: Optional[int] = None, transform: Optional[Callable] = None,
-                 cache_mode: CacheMode = CacheMode.NONE, save_precache: bool = False,
+                 cache_mode: CacheMode = CacheMode.NONE,
+                 precache_location: CacheLocation = CacheLocation.NONE,
                  cache_dir: Optional[Path] = None,
                  number_of_cross_validation_splits: int = 0,
                  cross_validation_split_index: int = 0) -> None:
@@ -46,19 +50,24 @@ def __init__(self, root_path: Path, max_bag_size: int = 0, batch_size: int = 1,
         :param cache_mode: The type of caching to perform, i.e. whether the results of all transforms
             up to the first randomised one should be computed only once and reused in subsequent
             iterations:
-            - `MEMORY`: the entire transformed dataset is kept in memory for fastest access;
-            - `DISK`: each transformed sample is saved to disk and loaded on-demand;
-            - `NONE` (default): no caching is performed.
-        :param save_precache: Whether to pre-cache the entire transformed dataset upfront and save
+            - `MEMORY`: a MONAI `CacheDataset` is used; the entire transformed dataset is kept in memory for fastest access;
+            - `DISK`: a MONAI `PersistentDataset` is used; each transformed sample is saved to disk and loaded on-demand;
+            - `NONE` (default): a standard MONAI `Dataset` is used; no caching is performed.
+        :param precache_location: Whether to pre-cache the entire transformed dataset upfront and save
             it to disk. This is done once in `prepare_data()` only on the local rank-0 process, so
-            multiple processes can afterwards access the same cache without contention in DDP settings.
+            multiple processes can afterwards access the same cache without contention in DDP settings. This parameter
+            also allows choosing whether the cache will be reloaded into CPU or GPU memory:
+            - `NONE` (default): no pre-caching is performed;
+            - `CPU`: each transformed sample is saved to disk and, if `cache_mode` is `MEMORY`, reloaded into CPU memory;
+            - `SAME`: each transformed sample is saved to disk and, if `cache_mode` is `MEMORY`, reloaded onto the same device it was saved from.
+            If `cache_mode` is `DISK`, the `CPU` and `SAME` pre-cache locations are equivalent.
         :param cache_dir: The directory onto which to cache data if caching is enabled.
         :param number_of_cross_validation_splits: Number of folds to perform.
         :param cross_validation_split_index: Index of the cross validation split to be performed.
         """
-        if save_precache and cache_mode is CacheMode.NONE:
+        if precache_location is not CacheLocation.NONE and cache_mode is CacheMode.NONE:
             raise ValueError("Can only pre-cache if caching is enabled")
-        if save_precache and cache_dir is None:
+        if precache_location is not CacheLocation.NONE and cache_dir is None:
             raise ValueError("A cache directory is required for pre-caching")
         if cache_mode is CacheMode.DISK and cache_dir is None:
             raise ValueError("A cache directory is required for on-disk caching")
@@ -68,7 +77,7 @@ def __init__(self, root_path: Path, max_bag_size: int = 0, batch_size: int = 1,
         self.max_bag_size = max_bag_size
         self.transform = transform
         self.cache_mode = cache_mode
-        self.save_precache = save_precache
+        self.precache_location = precache_location
         self.cache_dir = cache_dir
         self.batch_size = batch_size
         self.number_of_cross_validation_splits = number_of_cross_validation_splits
@@ -82,7 +91,7 @@ def get_splits(self) -> Tuple[TilesDataset, TilesDataset, TilesDataset]:
         raise NotImplementedError
 
     def prepare_data(self) -> None:
-        if self.save_precache:
+        if self.precache_location != CacheLocation.NONE:
             self._load_dataset(self.train_dataset, stage='train', shuffle=True)
             self._load_dataset(self.val_dataset, stage='val', shuffle=True)
             self._load_dataset(self.test_dataset, stage='test', shuffle=True)
 
@@ -90,14 +99,21 @@ def prepare_data(self) -> None:
     def _dataset_pickle_path(self, stage: str) -> Optional[Path]:
         if self.cache_dir is None:
             return None
-        return self.cache_dir / f"{stage}_dataset.pkl"
+        return self.cache_dir / f"{stage}_dataset.pt"
 
     def _load_dataset(self, tiles_dataset: TilesDataset, stage: str, shuffle: bool) -> Dataset:
         dataset_pickle_path = self._dataset_pickle_path(stage)
 
-        if dataset_pickle_path and dataset_pickle_path.exists():
+        if dataset_pickle_path and dataset_pickle_path.is_file():
+            if self.precache_location == CacheLocation.CPU:
+                memory_location = torch.device('cpu')
+                print(f"Loading dataset from {dataset_pickle_path} into {memory_location}")
+            else:
+                # By default, torch.load will reload tensors onto the same device they were saved from
+                memory_location = None  # type: ignore
+
             with dataset_pickle_path.open('rb') as f:
-                return pickle.load(f)
+                return torch.load(f, map_location=memory_location)
 
         generator = _create_generator(self.seed)
 
         bag_dataset = BagDataset(tiles_dataset,  # type: ignore
@@ -112,10 +128,11 @@ def _load_dataset(self, tiles_dataset: TilesDataset, stage: str, shuffle: bool)
         transformed_bag_dataset = self._get_transformed_dataset(bag_dataset, transform)  # type: ignore
         generator.set_state(generator_state)
 
+        # The dataset is saved whenever a cache_dir is provided, regardless of cache_mode
         if dataset_pickle_path:
             dataset_pickle_path.parent.mkdir(parents=True, exist_ok=True)
             with dataset_pickle_path.open('wb') as f:
-                pickle.dump(transformed_bag_dataset, f)
+                torch.save(transformed_bag_dataset, f)
 
         return transformed_bag_dataset
 
@@ -125,7 +142,7 @@ def _get_transformed_dataset(self, base_dataset: BagDataset,
             dataset = CacheDataset(base_dataset, transform, num_workers=1)  # type: ignore
         elif self.cache_mode is CacheMode.DISK:
             dataset = PersistentDataset(base_dataset, transform, cache_dir=self.cache_dir)  # type: ignore
-        if self.save_precache:
+        if self.precache_location != CacheLocation.NONE:
             import tqdm  # TODO: Make optional
 
             for i in tqdm.trange(len(dataset), desc="Loading dataset"):
diff --git a/InnerEye/ML/Histopathology/models/deepmil.py b/InnerEye/ML/Histopathology/models/deepmil.py
index 827d23a03..14ca78c00 100644
--- a/InnerEye/ML/Histopathology/models/deepmil.py
+++ b/InnerEye/ML/Histopathology/models/deepmil.py
@@ -51,7 +51,7 @@ def __init__(self,
                  weight_decay: float = 1e-4,
                  adam_betas: Tuple[float, float] = (0.9, 0.99),
                  verbose: bool = False,
-                 slide_dataset: SlidesDataset = None,
+                 slide_dataset: SlidesDataset = None,
                  tile_size: int = 224,
                  level: int = 1) -> None:
         """
@@ -324,13 +324,13 @@ def test_epoch_end(self, outputs: List[Dict[str, Any]]) -> None:  # type: ignore
             if self.slide_dataset is not None:
                 slide_dict = mi.first_true(self.slide_dataset, pred=lambda entry: entry[SlideKey.SLIDE_ID] == slide)  # type: ignore
-                _ = load_image_dict(slide_dict, level=self.level, margin=0)  # type: ignore
+                _ = load_image_dict(slide_dict, level=self.level, margin=0)  # type: ignore
                 slide_image = slide_dict[SlideKey.IMAGE]
                 location_bbox = slide_dict[SlideKey.LOCATION]
 
                 fig = plot_slide(slide_image=slide_image, scale=1.0)
                 self.save_figure(fig=fig, figpath=Path(key_folder_path, f'{slide}_thumbnail.png'))
 
-                fig = plot_heatmap_overlay(slide=slide, slide_image=slide_image, results=results,
+                fig = plot_heatmap_overlay(slide=slide, slide_image=slide_image, results=results,
                                            location_bbox=location_bbox, tile_size=self.tile_size, level=self.level)
                 self.save_figure(fig=fig, figpath=Path(key_folder_path, f'{slide}_heatmap.png'))
diff --git a/InnerEye/ML/Histopathology/models/transforms.py b/InnerEye/ML/Histopathology/models/transforms.py
index e50d088e0..1d61fdcb3 100644
--- a/InnerEye/ML/Histopathology/models/transforms.py
+++ b/InnerEye/ML/Histopathology/models/transforms.py
@@ -6,8 +6,9 @@
 from pathlib import Path
 from typing import Mapping, Sequence, Union
 
-import PIL.Image
 import torch
+import numpy as np
+import PIL.PngImagePlugin
 from monai.config.type_definitions import KeysCollection
 from monai.transforms.transform import MapTransform
 from torchvision.transforms.functional import to_tensor
@@ -19,7 +20,9 @@
 
-def load_pil_image(image_path: PathOrString) -> PIL.Image.Image:
-    """Load a PIL image in RGB format from the given path"""
-    return PIL.Image.open(image_path).convert('RGB')
+def load_pil_image(image_path: PathOrString) -> np.ndarray:
+    """Load a PNG image as a numpy array from the given path"""
+    with PIL.PngImagePlugin.PngImageFile(image_path) as pil_png:
+        image = np.asarray(pil_png)
+    return image
 
 
 def load_image_as_tensor(image_path: PathOrString) -> torch.Tensor:
@@ -86,19 +89,34 @@ class EncodeTilesBatchd(MapTransform):
 
     def __init__(self,
                  keys: KeysCollection,
                  encoder: TileEncoder,
-                 allow_missing_keys: bool = False) -> None:
+                 allow_missing_keys: bool = False,
+                 chunk_size: int = 0) -> None:
         """
         :param keys: Key(s) for the image path(s) in the input dictionary.
         :param encoder: The tile encoder to use for feature extraction.
         :param allow_missing_keys: If `False` (default), raises an exception when an input
             dictionary is missing any of the specified keys.
+        :param chunk_size: If > 0, extracts features in chunks of `chunk_size` tiles.
         """
         super().__init__(keys, allow_missing_keys)
         self.encoder = encoder
+        self.chunk_size = chunk_size
 
     @torch.no_grad()
     def _encode_tiles(self, images: torch.Tensor) -> torch.Tensor:
         device = next(self.encoder.parameters()).device
+        if self.chunk_size > 0:
+            embeddings = []
+            chunks = torch.split(images, self.chunk_size)
+            # TODO: Parallelize encoding - keep metadata and images aligned
+            for chunk in chunks:
+                chunk_embeddings = self._encode_images(chunk, device)
+                embeddings.append(chunk_embeddings)
+            return torch.cat(embeddings)
+        else:
+            return self._encode_images(images, device)
+
+    def _encode_images(self, images: torch.Tensor, device: torch.device) -> torch.Tensor:
         images = images.to(device)
         embeddings = self.encoder(images)
         del images
diff --git a/InnerEye/ML/Histopathology/utils/metrics_utils.py b/InnerEye/ML/Histopathology/utils/metrics_utils.py
index b6f49f081..93f9a4a22 100644
--- a/InnerEye/ML/Histopathology/utils/metrics_utils.py
+++ b/InnerEye/ML/Histopathology/utils/metrics_utils.py
@@ -8,7 +8,7 @@ import matplotlib.pyplot as plt
 from math import ceil
 import numpy as np
-import matplotlib.patches as patches
+import matplotlib.patches as patches
 import matplotlib.collections as collection
 
 from InnerEye.ML.Histopathology.models.transforms import load_pil_image
@@ -117,7 +117,7 @@ def plot_slide(slide_image: np.ndarray, scale: float) -> plt.figure:
     return fig
 
 
-def plot_heatmap_overlay(slide: str,
+def plot_heatmap_overlay(slide: str,
                          slide_image: np.ndarray,
                          results: Dict[str, List[Any]],
                          location_bbox: List[int],
diff --git a/InnerEye/ML/configs/histo_configs/classification/BaseMIL.py b/InnerEye/ML/configs/histo_configs/classification/BaseMIL.py
index 521711b17..4328de36b 100644
--- a/InnerEye/ML/configs/histo_configs/classification/BaseMIL.py
+++ b/InnerEye/ML/configs/histo_configs/classification/BaseMIL.py
@@ -8,7 +8,7 @@ their datamodules and configure experiment-specific parameters.
""" from pathlib import Path -from typing import Type +from typing import Type # noqa import param from torch import nn @@ -17,7 +17,7 @@ from health_ml.networks.layers.attention_layers import AttentionLayer, GatedAttentionLayer from InnerEye.ML.lightning_container import LightningContainer from InnerEye.ML.Histopathology.datasets.base_dataset import SlidesDataset -from InnerEye.ML.Histopathology.datamodules.base_module import CacheMode, TilesDataModule +from InnerEye.ML.Histopathology.datamodules.base_module import CacheMode, CacheLocation, TilesDataModule from InnerEye.ML.Histopathology.models.deepmil import DeepMILModule from InnerEye.ML.Histopathology.models.encoders import (HistoSSLEncoder, IdentityEncoder, ImageNetEncoder, ImageNetSimCLREncoder, @@ -44,8 +44,12 @@ class BaseMIL(LightningContainer): cache_mode: CacheMode = param.ClassSelector(default=CacheMode.MEMORY, class_=CacheMode, doc="The type of caching to perform: " "'memory' (default), 'disk', or 'none'.") - save_precache: bool = param.Boolean(True, doc="Whether to pre-cache the entire transformed " - "dataset upfront and save it to disk.") + precache_location: str = param.ClassSelector(default=CacheLocation.NONE, class_=CacheLocation, + doc="Whether to pre-cache the entire transformed dataset upfront " + "and save it to disk and if re-load in cpu or gpu. Options:" + "`none` (default),`cpu`, `gpu`") + encoding_chunk_size: int = param.Integer(0, doc="If > 0 performs encoding in chunks, by loading" + "enconding_chunk_size tiles per chunk") # local_dataset (used as data module root_path) is declared in DatasetParams superclass @property diff --git a/InnerEye/ML/configs/histo_configs/classification/DeepSMILEPanda.py b/InnerEye/ML/configs/histo_configs/classification/DeepSMILEPanda.py index 59a617b74..c72dc3f66 100644 --- a/InnerEye/ML/configs/histo_configs/classification/DeepSMILEPanda.py +++ b/InnerEye/ML/configs/histo_configs/classification/DeepSMILEPanda.py @@ -39,6 +39,9 @@ def __init__(self, **kwargs: Any) -> None: default_kwargs = dict( # declared in BaseMIL: pooling_type=GatedAttentionLayer.__name__, + # average number of tiles is 56 for PANDA + encoding_chunk_size=60, + # declared in DatasetParams: local_dataset=Path("/tmp/datasets/PANDA_tiles"), azure_dataset_id="PANDA_tiles", @@ -48,9 +51,11 @@ def __init__(self, **kwargs: Any) -> None: # declared in TrainerParams: num_epochs=200, # use_mixed_precision = True, + # declared in WorkflowParams: number_of_cross_validation_splits=5, cross_validation_split_index=0, + # declared in OptimizerParams: l_rate=5e-4, weight_decay=1e-4, @@ -101,7 +106,7 @@ def get_data_module(self) -> PandaTilesDataModule: transform = Compose( [ LoadTilesBatchd(image_key, progress=True), - EncodeTilesBatchd(image_key, self.encoder), + EncodeTilesBatchd(image_key, self.encoder, chunk_size=self.encoding_chunk_size), ] ) return PandaTilesDataModule( @@ -110,7 +115,7 @@ def get_data_module(self) -> PandaTilesDataModule: batch_size=self.batch_size, transform=transform, cache_mode=self.cache_mode, - save_precache=self.save_precache, + precache_location=self.precache_location, cache_dir=self.cache_dir, number_of_cross_validation_splits=self.number_of_cross_validation_splits, cross_validation_split_index=self.cross_validation_split_index, @@ -134,7 +139,7 @@ def create_model(self) -> DeepMILModule: tile_size=self.tile_size, level=self.level) - def get_slide_dataset(self) -> PandaDataset: + def get_slide_dataset(self) -> PandaDataset: return PandaDataset(root=self.extra_local_dataset_paths[0]) # type: ignore 
 
     def get_callbacks(self) -> List[Callback]:
@@ -162,8 +167,7 @@ def get_path_to_best_checkpoint(self) -> Path:
             if checkpoint_path.is_file():
                 return checkpoint_path
 
-        raise ValueError("Path to best checkpoint not found")
-
+        raise ValueError("Path to best checkpoint not found")
 
 class PandaImageNetMIL(DeepSMILEPanda):
     def __init__(self, **kwargs: Any) -> None:
diff --git a/InnerEye/ML/utils/io_util.py b/InnerEye/ML/utils/io_util.py
index c1b6b91f7..d25024a18 100644
--- a/InnerEye/ML/utils/io_util.py
+++ b/InnerEye/ML/utils/io_util.py
@@ -461,7 +461,7 @@ def load_labels_from_dataset_source(dataset_source: PatientDatasetSource, check_
 
 def load_image(path: PathOrString, image_type: Optional[Type] = float) -> ImageWithHeader:
     """
-    Loads an image with extension numpy or nifti
+    Loads an image with extension numpy, nifti, or png
     For HDF5 path suffix
         For images ||
         For segmentation binary ||
diff --git a/Tests/ML/histopathology/datamodules/test_datamodule_caching.py b/Tests/ML/histopathology/datamodules/test_datamodule_caching.py
index 41325f401..7bd4ac2e5 100644
--- a/Tests/ML/histopathology/datamodules/test_datamodule_caching.py
+++ b/Tests/ML/histopathology/datamodules/test_datamodule_caching.py
@@ -13,7 +13,7 @@
 import torch
 from torch.utils.data import DataLoader
 
-from InnerEye.ML.Histopathology.datamodules.base_module import CacheMode, TilesDataModule
+from InnerEye.ML.Histopathology.datamodules.base_module import CacheMode, CacheLocation, TilesDataModule
 from InnerEye.ML.Histopathology.datasets.base_dataset import TilesDataset
 
@@ -50,6 +50,7 @@ class MockTilesDataset(TilesDataset):
 
 def generate_mock_dataset_df(n_slides: int, n_tiles: int, n_classes: int) -> pd.DataFrame:
+    np.random.seed(1234)
     slide_ids = np.random.randint(n_slides, size=n_tiles)
     slide_labels = np.random.randint(n_classes, size=n_slides)
     tile_labels = slide_labels[slide_ids]
@@ -91,14 +92,14 @@ def mock_data_dir(tmp_path: Path) -> Path:
     df.to_csv(csv_path, index=False)
     return csv_dir
 
-def _get_datamodule(cache_mode: CacheMode, save_precache: bool,
+def _get_datamodule(cache_mode: CacheMode, precache_location: CacheLocation,
                     cache_dir_provided: bool, data_dir: Path) -> TilesDataModule:
-    if (cache_mode is CacheMode.NONE and save_precache) \
+    if (cache_mode is CacheMode.NONE and precache_location is not CacheLocation.NONE) \
             or (cache_mode is CacheMode.DISK and not cache_dir_provided) \
-            or (save_precache and not cache_dir_provided):
+            or (precache_location is not CacheLocation.NONE and not cache_dir_provided):
         pytest.skip("Unsupported combination of caching arguments")
 
-    cache_dir = data_dir / f"datamodule_cache_{cache_mode.value}" if cache_dir_provided else None
+    cache_dir = data_dir / f"datamodule_cache_{cache_mode.value}_{precache_location.value}" if cache_dir_provided else None
 
     if cache_dir is not None and cache_dir.exists():
         shutil.rmtree(cache_dir)
@@ -108,18 +109,18 @@ def _get_datamodule(cache_mode: CacheMode, save_precache: bool,
                                seed=0,
                                batch_size=2,
                                cache_mode=cache_mode,
-                               save_precache=save_precache,
+                               precache_location=precache_location,
                                cache_dir=cache_dir)
 
 
 @pytest.mark.parametrize('cache_mode', [CacheMode.MEMORY, CacheMode.DISK, CacheMode.NONE])
-@pytest.mark.parametrize('save_precache', [True, False])
+@pytest.mark.parametrize('precache_location', [CacheLocation.NONE, CacheLocation.CPU, CacheLocation.SAME])
 @pytest.mark.parametrize('cache_dir_provided', [True, False])
-def test_caching_consistency(mock_data_dir: Path, cache_mode: CacheMode, save_precache: bool,
+def test_caching_consistency(mock_data_dir: Path, cache_mode: CacheMode, precache_location: CacheLocation,
                              cache_dir_provided: bool) -> None:
     # Compare two dataloaders from the same datamodule
     datamodule = _get_datamodule(cache_mode=cache_mode,
-                                 save_precache=save_precache,
+                                 precache_location=precache_location,
                                  cache_dir_provided=cache_dir_provided,
                                  data_dir=mock_data_dir)
     datamodule.prepare_data()
@@ -130,14 +131,14 @@
     # Compare datamodules reusing the same cache
     datamodule = _get_datamodule(cache_mode=cache_mode,
-                                 save_precache=save_precache,
+                                 precache_location=precache_location,
                                  cache_dir_provided=cache_dir_provided,
                                  data_dir=mock_data_dir)
     datamodule.prepare_data()
     train_dataloader = datamodule.train_dataloader()
 
     reloaded_datamodule = _get_datamodule(cache_mode=cache_mode,
-                                          save_precache=save_precache,
+                                          precache_location=precache_location,
                                           cache_dir_provided=cache_dir_provided,
                                           data_dir=mock_data_dir)
     reloaded_datamodule.prepare_data()
@@ -146,13 +147,18 @@
     compare_dataloaders(train_dataloader, reloaded_train_dataloader)
 
 
-@pytest.mark.parametrize('cache_mode', [CacheMode.MEMORY, CacheMode.DISK, CacheMode.NONE])
-@pytest.mark.parametrize('save_precache', [True, False])
-@pytest.mark.parametrize('cache_dir_provided', [True, False])
-def test_tile_id_coverage(mock_data_dir: Path, cache_mode: CacheMode, save_precache: bool,
+@pytest.mark.parametrize('cache_mode, precache_location, cache_dir_provided',
+                         [(CacheMode.DISK, CacheLocation.SAME, True),
+                          (CacheMode.DISK, CacheLocation.CPU, True),
+                          (CacheMode.MEMORY, CacheLocation.SAME, True),
+                          (CacheMode.MEMORY, CacheLocation.CPU, True),
+                          (CacheMode.MEMORY, CacheLocation.NONE, False),
+                          (CacheMode.NONE, CacheLocation.NONE, False)
+                          ])
+def test_tile_id_coverage(mock_data_dir: Path, cache_mode: CacheMode, precache_location: CacheLocation,
                           cache_dir_provided: bool) -> None:
     datamodule = _get_datamodule(cache_mode=cache_mode,
-                                 save_precache=save_precache,
+                                 precache_location=precache_location,
                                  cache_dir_provided=cache_dir_provided,
                                  data_dir=mock_data_dir)
     datamodule.prepare_data()
diff --git a/Tests/ML/histopathology/models/test_transforms.py b/Tests/ML/histopathology/models/test_transforms.py
index ba938169a..17aa8c60b 100644
--- a/Tests/ML/histopathology/models/test_transforms.py
+++ b/Tests/ML/histopathology/models/test_transforms.py
@@ -126,8 +126,10 @@ def test_cached_loading(tmp_path: Path) -> None:
 
 @pytest.mark.skipif(not os.path.isdir(TCGA_CRCK_DATASET_DIR),
                     reason="TCGA-CRCk tiles dataset is unavailable")
-@pytest.mark.parametrize('use_gpu', [False, True])
-def test_encode_tiles(tmp_path: Path, use_gpu: bool) -> None:
+@pytest.mark.parametrize('use_gpu, chunk_size',
+                         [(False, 0), (False, 2), (True, 0), (True, 2)]
+                         )
+def test_encode_tiles(tmp_path: Path, use_gpu: bool, chunk_size: int) -> None:
     tiles_dataset = TcgaCrck_TilesDataset(TCGA_CRCK_DATASET_DIR)
     image_key = tiles_dataset.IMAGE_COLUMN
     max_bag_size = 5
@@ -138,7 +140,7 @@ def test_encode_tiles(tmp_path: Path, use_gpu: bool) -> None:
     if use_gpu:
         encoder.cuda()
 
-    encode_transform = EncodeTilesBatchd(image_key, encoder)
+    encode_transform = EncodeTilesBatchd(image_key, encoder, chunk_size=chunk_size)
     transform = Compose([LoadTilesBatchd(image_key), encode_transform])
     dataset = Dataset(bagged_dataset, transform=transform)  # type: ignore
     sample = dataset[0]
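
A note on the pre-cache reload path added in base_module.py: switching from pickle to torch.save/torch.load is what makes the CPU/SAME distinction possible, because torch.load can remap tensor storages at load time via map_location. A minimal standalone sketch of that logic (the helper function name is illustrative, not part of the PR):

from enum import Enum
from pathlib import Path

import torch


class CacheLocation(Enum):
    NONE = 'none'
    CPU = 'cpu'
    SAME = 'same'


def reload_precached_dataset(path: Path, precache_location: CacheLocation):
    # CPU: remap every tensor storage onto the CPU at load time.
    # SAME: map_location=None lets torch.load restore each tensor onto
    # the device it was saved from (e.g. back onto the original GPU).
    map_location = torch.device('cpu') if precache_location is CacheLocation.CPU else None
    with path.open('rb') as f:
        return torch.load(f, map_location=map_location)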
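
On the new load_pil_image in transforms.py: reading the PNG through PngImageFile and converting with np.asarray skips the .convert('RGB') copy, but np.asarray returns the array in whatever mode the PNG was stored (RGB, RGBA, or grayscale), whereas the old code always yielded 3-channel RGB. If non-RGB tiles can occur, a guard along these lines would restore the old invariant (hypothetical helper, not in the PR):

import numpy as np
import PIL.PngImagePlugin


def load_png_as_rgb_array(image_path: str) -> np.ndarray:
    # Read the PNG without forcing a mode conversion.
    with PIL.PngImagePlugin.PngImageFile(image_path) as pil_png:
        image = np.asarray(pil_png)
    if image.ndim == 2:
        # Grayscale PNG: replicate the single channel to get HxWx3.
        image = np.stack([image] * 3, axis=-1)
    elif image.shape[-1] == 4:
        # RGBA PNG: drop the alpha channel.
        image = image[..., :3]
    return image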
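
The chunking added to EncodeTilesBatchd bounds peak device memory: the bag of tile images stays on the host and is moved to the encoder's device one chunk at a time, so at most chunk_size tiles plus their embeddings are resident on the GPU at once. The same pattern in isolation, assuming any nn.Module encoder (function name and shapes are illustrative):

import torch
from torch import nn


@torch.no_grad()
def encode_in_chunks(images: torch.Tensor, encoder: nn.Module, chunk_size: int) -> torch.Tensor:
    device = next(encoder.parameters()).device
    # chunk_size == 0 means "encode the whole bag in one pass".
    chunks = torch.split(images, chunk_size) if chunk_size > 0 else (images,)
    # Only one chunk is moved to the device at a time.
    embeddings = [encoder(chunk.to(device)) for chunk in chunks]
    return torch.cat(embeddings)


# For example, a bag of 56 PANDA tiles encoded 16 tiles at a time:
# features = encode_in_chunks(torch.rand(56, 3, 224, 224), encoder, chunk_size=16)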
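
Finally, the two new knobs surface on BaseMIL as precache_location and encoding_chunk_size, next to the existing cache_mode. A hypothetical override mirroring the PANDA defaults (keyword routing through the container constructor is assumed from the existing **kwargs pattern and is not shown in this diff):

from InnerEye.ML.Histopathology.datamodules.base_module import CacheLocation, CacheMode
from InnerEye.ML.configs.histo_configs.classification.DeepSMILEPanda import DeepSMILEPanda

container = DeepSMILEPanda(
    cache_mode=CacheMode.MEMORY,          # keep the transformed dataset in memory (CacheDataset)
    precache_location=CacheLocation.CPU,  # pre-cache once, reload onto CPU in all DDP workers
    encoding_chunk_size=60,               # roughly one chunk per bag (average of 56 tiles on PANDA)
)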