
checkpoint saving stuck when using multiple GPUs #6495

@sun-peach

Description


🐛 Bug

When I use multiple GPUs, the checkpoint-saving step gets stuck, while it works perfectly when I use only one GPU.

Please reproduce using the BoringModel

from abc import ABCMeta, abstractmethod
from argparse import ArgumentParser
from typing import Any, List

import numpy as np
import torch
import wandb
from pytorch_lightning.loggers import WandbLogger

# Project-specific helpers (Conditional_Source_Separation, fourier,
# init_weights_functional, get_estimation) are defined elsewhere in the repository.


class Spectrogram_based(Conditional_Source_Separation, metaclass=ABCMeta):

    @staticmethod
    def add_model_specific_args(parent_parser):
        parser = ArgumentParser(parents=[parent_parser], add_help=False)

        parser.add_argument('--n_fft', type=int, default=2048)
        parser.add_argument('--hop_length', type=int, default=1024)
        parser.add_argument('--num_frame', type=int, default=128)
        parser.add_argument('--spec_type', type=str, default='complex')
        parser.add_argument('--spec_est_mode', type=str, default='mapping')

        parser.add_argument('--train_loss', type=str, default='spec_mse')
        parser.add_argument('--val_loss', type=str, default='raw_l1')
        parser.add_argument('--unfreeze_stft_from', type=int, default=-1)  # -1 means never.

        return Conditional_Source_Separation.add_model_specific_args(parser)

    def __init__(self, n_fft, hop_length, num_frame,
                 spec_type, spec_est_mode,
                 conditional_spec2spec,
                 optimizer, lr,
                 train_loss, val_loss, hparams=None
                 ):
        super(Spectrogram_based, self).__init__(n_fft, hop_length, num_frame,
                                                optimizer, lr, hparams)

        self.n_fft = n_fft
        self.hop_length = hop_length
        self.num_frame = num_frame

        assert spec_type in ['magnitude', 'complex']
        assert spec_est_mode in ['masking', 'mapping']
        self.magnitude_based = spec_type == 'magnitude'
        self.masking_based = spec_est_mode == 'masking'
        self.stft = fourier.multi_channeled_STFT(n_fft=n_fft, hop_length=hop_length)
        self.stft.freeze()

        self.spec2spec = conditional_spec2spec
        self.valid_estimation_dict = {}
        self.val_loss = val_loss
        self.train_loss = train_loss

        self.init_weights()

    def init_weights(self):
        init_weights_functional(self.spec2spec,
                                self.spec2spec.activation)

    def training_step(self, batch, batch_idx):
        mixture_signal, target_signal, condition = batch
        loss = self.train_loss(self, mixture_signal, condition, target_signal)
        self.log('train_loss', loss, prog_bar=False, logger=True, on_step=False, on_epoch=True,
                 reduce_fx=torch.mean)
        return loss

    # Validation Process
    def on_validation_epoch_start(self):
        for target_name in self.target_names:
            self.valid_estimation_dict[target_name] = {mixture_idx: {}
                                                       for mixture_idx
                                                       in range(14)}

    def validation_step(self, batch, batch_idx):

        mixtures, targets, mixture_ids, window_offsets, input_conditions, target_names = batch

        loss = self.val_loss(self, mixtures, input_conditions, targets)

        self.log('raw_val_loss', loss, prog_bar=False, logger=False, reduce_fx=torch.mean)

        # Result Cache
        if 0 in mixture_ids.view(-1):
            estimated_targets = self.separate(mixtures, input_conditions)[:, self.trim_length:-self.trim_length]
            targets = targets[:, self.trim_length:-self.trim_length]

            for mixture, mixture_idx, window_offset, input_condition, target_name, estimated_target \
                    in zip(mixtures, mixture_ids, window_offsets, input_conditions, target_names, estimated_targets):

                if mixture_idx == 0:
                    self.valid_estimation_dict[target_name][mixture_idx.item()][
                        window_offset.item()] = estimated_target.detach().cpu().numpy()
        return loss

    def validation_epoch_end(self, outputs: List[Any]) -> None:
        for idx in [0]:
            estimation = {}
            for target_name in self.target_names:
                estimation[target_name] = get_estimation(idx, target_name, self.valid_estimation_dict)
                if estimation[target_name] is None:
                    continue
                estimation[target_name] = estimation[target_name].astype(np.float32)

                if self.current_epoch > 1 and isinstance(self.logger, WandbLogger):
                    track = estimation[target_name]
                    if track.shape[0] > 40 * 44100:
                        track = track[44100 * 20:44100 * 40]

                    self.logger.experiment.log({'result_sample_{}_{}'.format(self.current_epoch, target_name): [
                        wandb.Audio(track, caption='{}_{}'.format(idx, target_name), sample_rate=44100)]})

        reduced_loss = torch.stack(outputs).mean()
        self.log('val_loss', reduced_loss, prog_bar=False, logger=True, on_step=False, on_epoch=True, sync_dist=True)
        print(reduced_loss)

    @abstractmethod
    def to_spec(self, input_signal) -> torch.Tensor:
        pass

    @abstractmethod
    def separate(self, input_signal, input_condition) -> torch.Tensor:
        pass
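
Since the issue template asks for a BoringModel-style reproduction, here is a minimal self-contained sketch (hypothetical names and random data, not the model above) that logs the same val_loss metric monitored by the checkpoint callback below:

import torch
from torch.utils.data import DataLoader, Dataset
import pytorch_lightning as pl


class RandomDataset(Dataset):
    # Hypothetical stand-in dataset: 256 random feature vectors of size 64.
    def __init__(self, size=64, length=256):
        self.data = torch.randn(length, size)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]


class MinimalModel(pl.LightningModule):
    # Hypothetical minimal module: one linear layer, sum of outputs as the "loss".
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(64, 2)

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log('train_loss', loss, on_step=False, on_epoch=True)
        return loss

    def validation_step(self, batch, batch_idx):
        return self(batch).sum()

    def validation_epoch_end(self, outputs):
        # Same pattern as above: reduce the step outputs manually,
        # then log the metric that ModelCheckpoint monitors.
        reduced_loss = torch.stack(outputs).mean()
        self.log('val_loss', reduced_loss, sync_dist=True)

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)

    def train_dataloader(self):
        return DataLoader(RandomDataset(), batch_size=8)

    def val_dataloader(self):
        return DataLoader(RandomDataset(), batch_size=8)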

To Reproduce

My checkpoint_callback is

    checkpoint_callback = ModelCheckpoint(
        dirpath=ckpt_path,
        save_top_k=save_top_k,
        verbose=True,
        monitor="val_loss",
        save_last=False,
        period=args["check_val_every_n_epoch"],
        save_weights_only=args['save_weights_only']
    )

For multi-GPU runs I use the command python main.py --gpus 4 --distributed_backend ddp, and for single-GPU runs I use python main.py --gpus 1. I did not change anything else.
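
For context, a sketch of how the callback and these flags are presumably wired together in main.py (the wiring below is an assumption; only the flags and the callback arguments above come from my setup):

from argparse import ArgumentParser

import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint


def main():
    parser = ArgumentParser()
    # Adds the standard Trainer flags, including --gpus and --distributed_backend.
    parser = pl.Trainer.add_argparse_args(parser)
    args = parser.parse_args()

    checkpoint_callback = ModelCheckpoint(
        dirpath='checkpoints/',   # stands in for ckpt_path above
        save_top_k=3,             # stands in for save_top_k above
        verbose=True,
        monitor='val_loss',
        save_last=False,
        period=1,                 # stands in for args["check_val_every_n_epoch"] above
        save_weights_only=False,
    )

    model = MinimalModel()  # or the Spectrogram_based subclass shown above

    # python main.py --gpus 4 --distributed_backend ddp  -> hangs while saving
    # python main.py --gpus 1                            -> works
    trainer = pl.Trainer.from_argparse_args(args, callbacks=[checkpoint_callback])
    trainer.fit(model)


if __name__ == '__main__':
    main()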

Expected behavior

The model should save smoothly; instead, training hangs at the checkpoint-saving step. GPU utilization stays at 100% and never changes. Please see the figures below:
[Figure: GPU utilization stays at 100% forever]

[Figure: saving is stuck at epoch 0]

Environment

  • PyTorch Lightning version: 1.2.1
  • PyTorch Version (e.g., 1.0): 1.4
  • OS (e.g., Linux): Ubuntu 18.04
  • How you installed PyTorch (conda, pip, source): pip
  • Build command you used (if compiling from source):
  • Python version:
  • CUDA/cuDNN version: 10.2
  • GPU models and configuration:
  • Any other relevant information:


Labels

  • bug (Something isn't working)
  • checkpointing (Related to checkpointing)
  • distributed (Generic distributed-related topic)
  • help wanted (Open to be worked on)
  • priority: 0 (High priority task)
  • waiting on author (Waiting on user action, correction, or update)
