From 62abecbdac1ab0ab8f5bd7897fcac66a7779913b Mon Sep 17 00:00:00 2001
From: William Falcon
Date: Thu, 5 Sep 2019 23:25:04 -0400
Subject: [PATCH] refactored init

---
 pytorch_lightning/models/trainer.py | 172 +++++++++++++++-------------
 1 file changed, 95 insertions(+), 77 deletions(-)

diff --git a/pytorch_lightning/models/trainer.py b/pytorch_lightning/models/trainer.py
index abb4b16719dbf..28921f2768fe6 100644
--- a/pytorch_lightning/models/trainer.py
+++ b/pytorch_lightning/models/trainer.py
@@ -129,42 +129,13 @@ def __init__(self,
         self.track_grad_norm = track_grad_norm
         self.fast_dev_run = fast_dev_run
         self.on_gpu = gpus is not None and torch.cuda.is_available()
-        self.experiment = experiment
-        self.exp_save_path = None
-        if self.experiment is not None:
-            self.exp_save_path = experiment.get_data_path(experiment.name, experiment.version)
-        self.cluster = cluster
         self.process_position = process_position
         self.current_gpu_name = current_gpu_name
         self.print_weights_summary = print_weights_summary
-        self.checkpoint_callback = checkpoint_callback
-
-        if self.checkpoint_callback is not None:
-            self.checkpoint_callback.save_function = self.save_checkpoint
-
-        self.early_stop = early_stop_callback
-        self.model = None
         self.max_nb_epochs = max_nb_epochs
-        if isinstance(accumulate_grad_batches, dict):
-            self.accumulation_scheduler = GradientAccumulationScheduler(accumulate_grad_batches)
-        elif isinstance(accumulate_grad_batches, int):
-            schedule = {1: accumulate_grad_batches}
-            self.accumulation_scheduler = GradientAccumulationScheduler(schedule)
-        else:
-            raise TypeError("Gradient accumulation supports only int and dict types")
-        self.early_stop_callback = early_stop_callback
         self.min_nb_epochs = min_nb_epochs
         self.nb_sanity_val_steps = nb_sanity_val_steps
-        self.lr_schedulers = []
-        self.amp_level = amp_level
         self.print_nan_grads = print_nan_grads
-        self.data_parallel_device_ids = None
-        self.world_size = 1
-        self.node_rank = 0
-        self.use_ddp = False
-        self.use_dp = False
-        self.single_gpu = False
-        self.testing = False
 
         # training bookeeping
         self.total_batch_nb = 0
@@ -175,27 +146,114 @@ def __init__(self,
         self.nb_val_batches = 0
         self.nb_tng_batches = 0
         self.nb_test_batches = 0
+        self.tng_dataloader = None
+        self.test_dataloader = None
+        self.val_dataloader = None
+
+        # training state
+        self.model = None
+        self.testing = False
+        self.lr_schedulers = []
+        self.optimizers = None
+        self.global_step = 0
+        self.current_epoch = 0
+        self.total_batches = 0
+
+        # configure callbacks
+        self.early_stop_callback = early_stop_callback
+        self.checkpoint_callback = checkpoint_callback
+        if self.checkpoint_callback is not None:
+            self.checkpoint_callback.save_function = self.save_checkpoint
+
+        # configure experiment
+        self.experiment = experiment
+        self.exp_save_path = None
+        if self.experiment is not None:
+            self.exp_save_path = experiment.get_data_path(experiment.name, experiment.version)
+
+        # accumulated grads
+        self.__configure_accumulated_gradients(accumulate_grad_batches)
+
+        # allow string and gpu list
+        self.data_parallel_device_ids = self.__parse_gpu_ids(gpus)
+
+        # distributed backend choice
+        self.use_ddp = False
+        self.use_dp = False
+        self.single_gpu = False
+        self.cluster = cluster
+        self.__set_distributed_mode(distributed_backend, nb_gpu_nodes)
+
+        # init flags for SLURM+ddp to work
+        self.proc_rank = 0
+        self.world_size = 1
+        self.node_rank = 0
+        self.__configure_slurm_ddp()
+
+        # can't init progress bar here because starting a new process
+        # means the prog_bar won't survive pickling
+        self.show_progress_bar = show_progress_bar
+
+        # logging
+        self.log_save_interval = log_save_interval
+        self.val_check_interval = val_check_interval
+        self.add_log_row_interval = add_log_row_interval
+
+        # how much of the data to use
+        self.__determine_data_use_amount(train_percent_check, val_percent_check,
+                                         test_percent_check, overfit_pct)
+
+        # 16 bit mixed precision training using apex
+        self.amp_level = amp_level
+        self.__init_amp(use_amp)
+
+    def __init_amp(self, use_amp):
+        self.use_amp = use_amp and APEX_AVAILABLE
+        if self.use_amp:
+            print('using 16bit precision')
+
+        if use_amp and not APEX_AVAILABLE:  # pragma: no cover
+            msg = """
+            You set use_amp=True but do not have apex installed.
+            Install apex first using this guide and rerun with use_amp=True:
+            https://github.com/NVIDIA/apex#linux
+
+            this run will NOT use 16 bit precision
+            """
+            raise ModuleNotFoundError(msg)
+
+    def __configure_accumulated_gradients(self, accumulate_grad_batches):
+        if isinstance(accumulate_grad_batches, dict):
+            self.accumulation_scheduler = GradientAccumulationScheduler(accumulate_grad_batches)
+        elif isinstance(accumulate_grad_batches, int):
+            schedule = {1: accumulate_grad_batches}
+            self.accumulation_scheduler = GradientAccumulationScheduler(schedule)
+        else:
+            raise TypeError("Gradient accumulation supports only int and dict types")
+
+    def __parse_gpu_ids(self, gpus):
         # gpus come in as a string.
         # if gpus = -1 then use all available devices
         # otherwise, split the string using commas
         if gpus is not None:
             if type(gpus) is list:
-                self.data_parallel_device_ids = gpus
+                gpus = gpus
             elif type(gpus) is str:
                 if gpus == '-1':
-                    self.data_parallel_device_ids = list(range(0, torch.cuda.device_count()))
+                    gpus = list(range(0, torch.cuda.device_count()))
                 else:
-                    self.data_parallel_device_ids = [int(x.strip()) for x in gpus.split(',')]
+                    gpus = [int(x.strip()) for x in gpus.split(',')]
             else:
                 raise Exception('gpus has to be a string or list of ids')
 
             # set the correct cuda visible devices (using pci order)
             os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
-            os.environ["CUDA_VISIBLE_DEVICES"] = ','.join([str(x) for x in
-                                                           self.data_parallel_device_ids])
+            os.environ["CUDA_VISIBLE_DEVICES"] = ','.join([str(x) for x in gpus])
             print('VISIBLE GPUS: %r' % os.environ["CUDA_VISIBLE_DEVICES"])
 
+        return gpus
+
+    def __set_distributed_mode(self, distributed_backend, nb_gpu_nodes):
         # make DP and DDP mutually exclusive
         # single GPU will also use DP with devices=[0]
         requested_gpus = self.data_parallel_device_ids is not None
@@ -218,6 +276,9 @@ def __init__(self,
             self.use_dp = False
             self.single_gpu = True
 
+        print('gpu available: {}, used: {}'.format(torch.cuda.is_available(), self.on_gpu))
+
+    def __configure_slurm_ddp(self):
         # extract SLURM flag vars
         # whenever we have the correct number of tasks, we let slurm manage processes
         # otherwise we launch the required number of processes
@@ -231,49 +292,6 @@ def __init__(self,
            # likely not on slurm, so set the slurm managed flag to false
            self.is_slurm_managing_tasks = False
 
-        # process info
-        self.proc_rank = 0
-
-        # training state
-        self.optimizers = None
-        self.global_step = 0
-        self.current_epoch = 0
-        self.total_batches = 0
-
-        # can't init progress bar here because starting a new process
-        # means the prog_bar won't survive pickling
-        self.show_progress_bar = show_progress_bar
-
-        # logging
-        self.log_save_interval = log_save_interval
-        self.val_check_interval = val_check_interval
-        self.add_log_row_interval = add_log_row_interval
-
-        # dataloaders
-        self.tng_dataloader = None
-        self.test_dataloader = None
-        self.val_dataloader = None
-
-        # how much of the data to use
-        self.__determine_data_use_amount(train_percent_check, val_percent_check,
-                                         test_percent_check, overfit_pct)
-        print('gpu available: {}, used: {}'.format(torch.cuda.is_available(), self.on_gpu))
-
-        # 16 bit mixed precision training using apex
-        self.use_amp = use_amp and APEX_AVAILABLE
-        if self.use_amp:
-            print('using 16bit precision')
-
-        if use_amp and not APEX_AVAILABLE:  # pragma: no cover
-            msg = """
-            You set use_amp=True but do not have apex installed.
-            Install apex first using this guide and rerun with use_amp=True:
-            https://github.com/NVIDIA/apex#linux
-
-            this run will NOT use 16 bit precision
-            """
-            raise ModuleNotFoundError(msg)
-
     def restore_state_if_existing_checkpoint(self):
         # restore trainer state and model if there is a weight for this experiment
         last_epoch = -1
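
A note on the extracted __configure_accumulated_gradients helper: GradientAccumulationScheduler takes a dict mapping an epoch to the number of batches to accumulate, so a plain int is normalized to a {1: n} schedule that applies from the first epoch onward. Below is a minimal standalone sketch of that normalization outside the Trainer (the normalize_accumulation_schedule name is illustrative, not part of the patch):

    def normalize_accumulation_schedule(accumulate_grad_batches):
        # mirrors the patch's int/dict handling of accumulate_grad_batches
        if isinstance(accumulate_grad_batches, dict):
            return accumulate_grad_batches           # e.g. {5: 2, 10: 4}
        if isinstance(accumulate_grad_batches, int):
            return {1: accumulate_grad_batches}      # accumulate from epoch 1 onward
        raise TypeError("Gradient accumulation supports only int and dict types")

    print(normalize_accumulation_schedule(4))        # {1: 4}
    print(normalize_accumulation_schedule({5: 2}))   # {5: 2}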
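Similarly, __parse_gpu_ids now returns the parsed ids (and exports CUDA_VISIBLE_DEVICES as a side effect) instead of assigning self.data_parallel_device_ids inline. A rough sketch of the accepted input shapes, with torch.cuda.device_count() swapped for a fixed count so it runs without CUDA (parse_gpu_ids and n_devices are illustrative names, not the patch's API):

    def parse_gpu_ids(gpus, n_devices=4):
        # accepts the same shapes as the patch: None, a list of ids,
        # '-1' for all devices, or a comma-separated string like '0, 2'
        if gpus is None:
            return None
        if isinstance(gpus, list):
            return gpus
        if isinstance(gpus, str):
            if gpus == '-1':
                return list(range(n_devices))        # all visible devices
            return [int(x.strip()) for x in gpus.split(',')]
        raise Exception('gpus has to be a string or list of ids')

    print(parse_gpu_ids('-1'))       # [0, 1, 2, 3]
    print(parse_gpu_ids('0, 2'))     # [0, 2]
    print(parse_gpu_ids([1, 3]))     # [1, 3]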