Skip to content
This repository was archived by the owner on Mar 21, 2024. It is now read-only.
41 changes: 30 additions & 11 deletions InnerEye/ML/configs/classification/CovidHierarchicalModel.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import codecs
import logging
import pickle
import random
import math
from pathlib import Path

from typing import Any, Callable
Expand Down Expand Up @@ -62,7 +64,6 @@ class CovidHierarchicalModel(ScalarModelBase):
"is assumed to contain unique ids.")

def __init__(self, covid_dataset_id: str = COVID_DATASET_ID, **kwargs: Any):
learning_rate = 1e-5 if self.use_pretrained_model else 1e-4
super().__init__(target_names=['CVX03vs12', 'CVX0vs3', 'CVX1vs2'],
loss_type=ScalarLoss.CustomClassification,
class_names=['CVX0', 'CVX1', 'CVX2', 'CVX3'],
Expand All @@ -81,10 +82,13 @@ def __init__(self, covid_dataset_id: str = COVID_DATASET_ID, **kwargs: Any):
num_epochs=50,
l_rate_scheduler=LRSchedulerType.Step,
l_rate_step_gamma=1.0,
l_rate=learning_rate,
l_rate_multi_step_milestones=None,
**kwargs)
self.num_classes = 3

def validate(self) -> None:
self.l_rate = 1e-5 if self.use_pretrained_model else 1e-4
super().validate()
if not self.use_pretrained_model and self.freeze_encoder:
raise ValueError("No encoder to freeze when training from scratch. You requested training from scratch and"
"encoder freezing.")
Expand All @@ -94,15 +98,30 @@ def should_generate_multilabel_report(self) -> bool:

def get_model_train_test_dataset_splits(self, dataset_df: pd.DataFrame) -> DatasetSplits:
if self.test_set_ids_csv:
test_df = pd.read_csv(self.local_dataset / self.test_set_ids_csv)
in_test_set = dataset_df.series.isin(test_df.series)
train_ids = dataset_df.series[~in_test_set].values
test_ids = dataset_df.series[in_test_set].values
num_val_samples = 400
val_ids = train_ids[:num_val_samples]
train_ids = train_ids[num_val_samples:]
return DatasetSplits.from_subject_ids(dataset_df, train_ids=train_ids, val_ids=val_ids, test_ids=test_ids,
subject_column="series", group_column="subject")
test_set_ids_csv = self.local_dataset / self.test_set_ids_csv
test_series = pd.read_csv(test_set_ids_csv).series

all_series = dataset_df.series.values
check_all_test_series = all(test_series.isin(all_series))
if not check_all_test_series:
raise ValueError(f"Not all test series from {test_set_ids_csv} were found in the dataset.")

test_set_subjects = dataset_df[dataset_df.series.isin(test_series)].subject.values
train_and_val_series = dataset_df[~dataset_df.subject.isin(test_set_subjects)].series.values
random.seed(42)
random.shuffle(train_and_val_series)
num_val_samples = math.floor(len(train_and_val_series) / 9)
val_series = train_and_val_series[:num_val_samples]
train_series = train_and_val_series[num_val_samples:]

logging.info(f"Dropped {len(all_series) - (len(test_series) + len(train_and_val_series))} series "
f"due to subject overlap with test set.")
return DatasetSplits.from_subject_ids(dataset_df,
train_ids=train_series,
val_ids=val_series,
test_ids=test_series,
subject_column="series",
group_column="subject")
else:
return DatasetSplits.from_proportions(dataset_df,
proportion_train=0.8,
Expand Down