172 changes: 95 additions & 77 deletions pytorch_lightning/models/trainer.py
@@ -129,42 +129,13 @@ def __init__(self,
self.track_grad_norm = track_grad_norm
self.fast_dev_run = fast_dev_run
self.on_gpu = gpus is not None and torch.cuda.is_available()
self.experiment = experiment
self.exp_save_path = None
if self.experiment is not None:
self.exp_save_path = experiment.get_data_path(experiment.name, experiment.version)
self.cluster = cluster
self.process_position = process_position
self.current_gpu_name = current_gpu_name
self.print_weights_summary = print_weights_summary
self.checkpoint_callback = checkpoint_callback

if self.checkpoint_callback is not None:
self.checkpoint_callback.save_function = self.save_checkpoint

self.early_stop = early_stop_callback
self.model = None
self.max_nb_epochs = max_nb_epochs
if isinstance(accumulate_grad_batches, dict):
self.accumulation_scheduler = GradientAccumulationScheduler(accumulate_grad_batches)
elif isinstance(accumulate_grad_batches, int):
schedule = {1: accumulate_grad_batches}
self.accumulation_scheduler = GradientAccumulationScheduler(schedule)
else:
raise TypeError("Gradient accumulation supports only int and dict types")
self.early_stop_callback = early_stop_callback
self.min_nb_epochs = min_nb_epochs
self.nb_sanity_val_steps = nb_sanity_val_steps
self.lr_schedulers = []
self.amp_level = amp_level
self.print_nan_grads = print_nan_grads
self.data_parallel_device_ids = None
self.world_size = 1
self.node_rank = 0
self.use_ddp = False
self.use_dp = False
self.single_gpu = False
self.testing = False

# training bookkeeping
self.total_batch_nb = 0
@@ -175,27 +146,114 @@ def __init__(self,
self.nb_val_batches = 0
self.nb_tng_batches = 0
self.nb_test_batches = 0
self.tng_dataloader = None
self.test_dataloader = None
self.val_dataloader = None

# training state
self.model = None
self.testing = False
self.lr_schedulers = []
self.optimizers = None
self.global_step = 0
self.current_epoch = 0
self.total_batches = 0

# configure callbacks
self.early_stop_callback = early_stop_callback
self.checkpoint_callback = checkpoint_callback
if self.checkpoint_callback is not None:
self.checkpoint_callback.save_function = self.save_checkpoint
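# (sketch of the assumed contract, not shown in this diff: the callback
# delegates disk writes back to the trainer, so a checkpoint callback
# would invoke self.save_checkpoint(filepath) whenever its monitored
# metric improves.)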

# configure experiment
self.experiment = experiment
self.exp_save_path = None
if self.experiment is not None:
self.exp_save_path = experiment.get_data_path(experiment.name, experiment.version)

# accumulated grads
self.__configure_accumulated_gradients(accumulate_grad_batches)

# allow string and gpu list
self.data_parallel_device_ids = self.__parse_gpu_ids(gpus)

# distributed backend choice
self.use_ddp = False
self.use_dp = False
self.single_gpu = False
self.cluster = cluster
self.__set_distributed_mode(distributed_backend, nb_gpu_nodes)

# init flags for SLURM+ddp to work
self.proc_rank = 0
self.world_size = 1
self.node_rank = 0
self.__configure_slurm_ddp()

# can't init progress bar here because starting a new process
# means the prog_bar won't survive pickling
self.show_progress_bar = show_progress_bar
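# (a rough reason, assumed rather than shown here: tqdm-style bars hold
# locks and stream handles that do not pickle, so a bar built before the
# worker processes fork would be lost or corrupted on the way over.)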

# logging
self.log_save_interval = log_save_interval
self.val_check_interval = val_check_interval
self.add_log_row_interval = add_log_row_interval
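# (presumed meaning of these knobs, not defined in this hunk:
# log_save_interval = batches between flushing logs to disk,
# val_check_interval = how far into the train epoch to run validation,
# add_log_row_interval = batches between appending a metrics row.)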

# how much of the data to use
self.__determine_data_use_amount(train_percent_check, val_percent_check,
test_percent_check, overfit_pct)
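# (presumed effect of __determine_data_use_amount, whose body is outside
# this hunk: clamp each split to the given fraction, with overfit_pct > 0
# overriding all three, e.g. overfit_pct=0.01 runs train/val/test on 1%
# of their batches.)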

# 16 bit mixed precision training using apex
self.amp_level = amp_level
self.__init_amp(use_amp)

def __init_amp(self, use_amp):
self.use_amp = use_amp and APEX_AVAILABLE
if self.use_amp:
print('using 16bit precision')

if use_amp and not APEX_AVAILABLE: # pragma: no cover
msg = """
You set use_amp=True but do not have apex installed.
Install apex first using this guide, then rerun with use_amp=True:
https://github.com/NVIDIA/apex#linux

Without apex, 16-bit precision is unavailable, so this run cannot continue.
"""
raise ModuleNotFoundError(msg)
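# Illustrative behaviour (not part of this diff): on a machine with apex
# installed, __init_amp(True) flips self.use_amp and prints the notice;
# without apex the same call raises ModuleNotFoundError rather than
# silently training in 32-bit.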

def __configure_accumulated_gradients(self, accumulate_grad_batches):
if isinstance(accumulate_grad_batches, dict):
self.accumulation_scheduler = GradientAccumulationScheduler(accumulate_grad_batches)
elif isinstance(accumulate_grad_batches, int):
schedule = {1: accumulate_grad_batches}
self.accumulation_scheduler = GradientAccumulationScheduler(schedule)
else:
raise TypeError("Gradient accumulation supports only int and dict types")
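# Illustrative schedules (not part of this diff; dict keys are assumed to
# be the epochs at which the accumulation factor changes):
#
# Trainer(accumulate_grad_batches=4) # same as {1: 4}: always accumulate 4 batches
# Trainer(accumulate_grad_batches={1: 1, 5: 4}) # step every batch until epoch 5, then every 4th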

def __parse_gpu_ids(self, gpus):
# gpus may come in as a string or a list of device ids.
# if gpus == '-1', use all available devices;
# otherwise, split the string on commas.
if gpus is not None:
if type(gpus) is list:
gpus = gpus
elif type(gpus) is str:
if gpus == '-1':
gpus = list(range(0, torch.cuda.device_count()))
else:
gpus = [int(x.strip()) for x in gpus.split(',')]
else:
raise Exception('gpus has to be a string or list of ids')

# set the correct cuda visible devices (using pci order)
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = ','.join([str(x) for x in
self.data_parallel_device_ids])
os.environ["CUDA_VISIBLE_DEVICES"] = ','.join([str(x) for x in gpus])
print('VISIBLE GPUS: %r' % os.environ["CUDA_VISIBLE_DEVICES"])

return gpus
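# Worked examples for the parser above (illustrative; assumes a 4-GPU box):
# __parse_gpu_ids('-1') -> [0, 1, 2, 3], CUDA_VISIBLE_DEVICES=0,1,2,3
# __parse_gpu_ids('0, 2') -> [0, 2], CUDA_VISIBLE_DEVICES=0,2
# __parse_gpu_ids([1, 3]) -> [1, 3], CUDA_VISIBLE_DEVICES=1,3
# __parse_gpu_ids(None) -> None, env vars left untouched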

def __set_distributed_mode(self, distributed_backend, nb_gpu_nodes):
# make DP and DDP mutually exclusive
# single GPU will also use DP with devices=[0]
requested_gpus = self.data_parallel_device_ids is not None
@@ -218,6 +276,9 @@ def __init__(self,
self.use_dp = False
self.single_gpu = True

print('gpu available: {}, used: {}'.format(torch.cuda.is_available(), self.on_gpu))

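# (assumed decision table for the collapsed branch above, inferred from
# the visible flags: multiple nodes or distributed_backend='ddp' sets
# self.use_ddp; distributed_backend='dp' over several GPUs sets
# self.use_dp; a single requested device sets self.single_gpu.)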
def __configure_slurm_ddp(self):
# extract SLURM flag vars
# whenever we have the correct number of tasks, we let slurm manage processes
# otherwise we launch the required number of processes
@@ -231,49 +292,6 @@ def __init__(self,
# likely not on slurm, so set the slurm managed flag to false
self.is_slurm_managing_tasks = False
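# A sketch of the check described above (the middle of the method is
# collapsed in this diff; SLURM_NTASKS is the standard SLURM task count,
# and the requested-process count here is assumed):
#
# nb_requested = len(self.data_parallel_device_ids or []) * self.nb_gpu_nodes
# self.is_slurm_managing_tasks = int(os.environ.get('SLURM_NTASKS', 0)) == nb_requested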

# process info
self.proc_rank = 0

# training state
self.optimizers = None
self.global_step = 0
self.current_epoch = 0
self.total_batches = 0

# can't init progress bar here because starting a new process
# means the prog_bar won't survive pickling
self.show_progress_bar = show_progress_bar

# logging
self.log_save_interval = log_save_interval
self.val_check_interval = val_check_interval
self.add_log_row_interval = add_log_row_interval

# dataloaders
self.tng_dataloader = None
self.test_dataloader = None
self.val_dataloader = None

# how much of the data to use
self.__determine_data_use_amount(train_percent_check, val_percent_check,
test_percent_check, overfit_pct)
print('gpu available: {}, used: {}'.format(torch.cuda.is_available(), self.on_gpu))

# 16 bit mixed precision training using apex
self.use_amp = use_amp and APEX_AVAILABLE
if self.use_amp:
print('using 16bit precision')

if use_amp and not APEX_AVAILABLE: # pragma: no cover
msg = """
You set use_amp=True but do not have apex installed.
Install apex first using this guide and rerun with use_amp=True:
https://github.com/NVIDIA/apex#linux

this run will NOT use 16 bit precision
"""
raise ModuleNotFoundError(msg)

def restore_state_if_existing_checkpoint(self):
# restore trainer state and model if there is a weight for this experiment
last_epoch = -1