diff --git a/InnerEye/ML/configs/classification/CovidHierarchicalModel.py b/InnerEye/ML/configs/classification/CovidHierarchicalModel.py index fd3c9d782..d7b99c343 100644 --- a/InnerEye/ML/configs/classification/CovidHierarchicalModel.py +++ b/InnerEye/ML/configs/classification/CovidHierarchicalModel.py @@ -1,6 +1,8 @@ import codecs import logging import pickle +import random +import math from pathlib import Path from typing import Any, Callable @@ -62,7 +64,6 @@ class CovidHierarchicalModel(ScalarModelBase): "is assumed to contain unique ids.") def __init__(self, covid_dataset_id: str = COVID_DATASET_ID, **kwargs: Any): - learning_rate = 1e-5 if self.use_pretrained_model else 1e-4 super().__init__(target_names=['CVX03vs12', 'CVX0vs3', 'CVX1vs2'], loss_type=ScalarLoss.CustomClassification, class_names=['CVX0', 'CVX1', 'CVX2', 'CVX3'], @@ -81,10 +82,13 @@ def __init__(self, covid_dataset_id: str = COVID_DATASET_ID, **kwargs: Any): num_epochs=50, l_rate_scheduler=LRSchedulerType.Step, l_rate_step_gamma=1.0, - l_rate=learning_rate, l_rate_multi_step_milestones=None, **kwargs) self.num_classes = 3 + + def validate(self) -> None: + self.l_rate = 1e-5 if self.use_pretrained_model else 1e-4 + super().validate() if not self.use_pretrained_model and self.freeze_encoder: raise ValueError("No encoder to freeze when training from scratch. You requested training from scratch and" "encoder freezing.") @@ -94,15 +98,30 @@ def should_generate_multilabel_report(self) -> bool: def get_model_train_test_dataset_splits(self, dataset_df: pd.DataFrame) -> DatasetSplits: if self.test_set_ids_csv: - test_df = pd.read_csv(self.local_dataset / self.test_set_ids_csv) - in_test_set = dataset_df.series.isin(test_df.series) - train_ids = dataset_df.series[~in_test_set].values - test_ids = dataset_df.series[in_test_set].values - num_val_samples = 400 - val_ids = train_ids[:num_val_samples] - train_ids = train_ids[num_val_samples:] - return DatasetSplits.from_subject_ids(dataset_df, train_ids=train_ids, val_ids=val_ids, test_ids=test_ids, - subject_column="series", group_column="subject") + test_set_ids_csv = self.local_dataset / self.test_set_ids_csv + test_series = pd.read_csv(test_set_ids_csv).series + + all_series = dataset_df.series.values + check_all_test_series = all(test_series.isin(all_series)) + if not check_all_test_series: + raise ValueError(f"Not all test series from {test_set_ids_csv} were found in the dataset.") + + test_set_subjects = dataset_df[dataset_df.series.isin(test_series)].subject.values + train_and_val_series = dataset_df[~dataset_df.subject.isin(test_set_subjects)].series.values + random.seed(42) + random.shuffle(train_and_val_series) + num_val_samples = math.floor(len(train_and_val_series) / 9) + val_series = train_and_val_series[:num_val_samples] + train_series = train_and_val_series[num_val_samples:] + + logging.info(f"Dropped {len(all_series) - (len(test_series) + len(train_and_val_series))} series " + f"due to subject overlap with test set.") + return DatasetSplits.from_subject_ids(dataset_df, + train_ids=train_series, + val_ids=val_series, + test_ids=test_series, + subject_column="series", + group_column="subject") else: return DatasetSplits.from_proportions(dataset_df, proportion_train=0.8,