diff --git a/pytorch_lightning/plugins/training_type/ipu.py b/pytorch_lightning/plugins/training_type/ipu.py
index 9e7ef653f3814..a2e408b66f98f 100644
--- a/pytorch_lightning/plugins/training_type/ipu.py
+++ b/pytorch_lightning/plugins/training_type/ipu.py
@@ -108,7 +108,6 @@ def __init__(
         os.environ["POPLAR_ENGINE_OPTIONS"] = json.dumps(options)
 
     def pre_dispatch(self) -> None:
-        self._handle_gradient_accumulation_steps()
         precision = self.lightning_module.trainer.precision
         model = LightningIPUModule(self.lightning_module, precision)
         self.model = model
@@ -127,6 +126,7 @@ def pre_dispatch(self) -> None:
                     options=self.inference_opts,
                 )
                 self.poptorch_models[x] = model
+        self._handle_gradient_accumulation_steps()
 
     @property
     def replication_factor(self):
@@ -136,7 +136,7 @@ def _create_opts(self, training: bool):
         opts = poptorch.Options()
         opts.deviceIterations(self.device_iterations)
         opts.replicationFactor(self.replication_factor)
-        gradient_accumulation = self.lightning_module.trainer.accumulate_grad_batches if training else 1
+        gradient_accumulation = self.accumulate_grad_batches if training else 1
         opts.Training.gradientAccumulation(gradient_accumulation)
 
         if os.environ.get("PL_GLOBAL_SEED"):
@@ -167,7 +167,7 @@ def _validate_opts(self, opts: 'poptorch.Options', training: bool) -> None:
                 )
                 opts.set(replication_factor=self.replication_factor)
             if training:
-                accumulate_grad_batches = self.lightning_module.trainer.accumulate_grad_batches
+                accumulate_grad_batches = self.accumulate_grad_batches
                 if opts.Training.gradient_accumulation != accumulate_grad_batches:
                     rank_zero_warn(
                         f"Training poptorch.Options set gradientAccumulation to {opts.Training.gradient_accumulation}. "
@@ -211,9 +211,9 @@ def process_dataloader(self, dataloader: Union[Iterable, DataLoader]) -> Union[I
             dataloader = apply_to_collection(dataloader, DataLoader, self.process_dataloader)
             return dataloader
         if not isinstance(dataloader, poptorch.DataLoader):
-            dataloader = self._convert_to_poptorch_loader(
-                dataloader=dataloader, opts=self._create_opts(training=self.lightning_module.training)
-            )
+            is_training = self.lightning_module.trainer.training
+            opts = self.training_opts if is_training else self.inference_opts
+            dataloader = self._convert_to_poptorch_loader(dataloader=dataloader, opts=opts)
         return dataloader
 
     def _convert_to_poptorch_loader(self, dataloader: Union[Iterable, DataLoader],
@@ -242,33 +242,44 @@ def _convert_to_poptorch_loader(self, dataloader: Union[Iterable, DataLoader],
             dataloader.multiprocessing_context = multiprocessing_context
         return dataloader
 
+    @property
+    def accumulate_grad_batches(self) -> int:
+        """
+        Tracks lazily the set accumulate_grad_batches in the trainer.
+        The IPUPlugin replaces the original accumulate_grad_batches.
+        """
+        if self._original_accumulate_grad_batches is None:
+            self._original_accumulate_grad_batches = self.lightning_module.trainer.accumulate_grad_batches
+            if not isinstance(self._original_accumulate_grad_batches, int):
+                raise MisconfigurationException(
+                    f"IPUs currently only support accumulate_grad_batches being an integer value. "
+                    f"Received {self.accumulate_grad_batches}"
+                )
+        return self._original_accumulate_grad_batches
+
     def _handle_gradient_accumulation_steps(self):
         """
         This functions overrides the trainer.accumulation_scheduler to generate
         ``accumulate_grad_batches=1``.
         Therefore, ``optimizer_step`` will be called on every batch, and the IPU will handle grad accumulation.
         """
-        self._original_accumulate_grad_batches = self.lightning_module.trainer.accumulate_grad_batches
-        if not isinstance(self._original_accumulate_grad_batches, int):
-            raise MisconfigurationException(
-                f"IPUs currently only support accumulate_grad_batches being an integer value. "
-                f"Received {self._original_accumulate_grad_batches}"
-            )
-        if self._original_accumulate_grad_batches > 1:
+        if self.accumulate_grad_batches > 1:
             self.lightning_module.trainer.accumulation_scheduler = GradientAccumulationScheduler({0: 1})
 
     def update_global_step(self, total_batch_idx: int, current_global_step: int) -> int:
-        if self._original_accumulate_grad_batches > 1:
-            if total_batch_idx % self._original_accumulate_grad_batches == 0:
+        if self.accumulate_grad_batches > 1:
+            if total_batch_idx % self.accumulate_grad_batches == 0:
                 current_global_step += 1
             return current_global_step
         return super().update_global_step(total_batch_idx, current_global_step)
 
     @property
     def _n_replicate(self):
-        # Ensure we replicate values to have enough dimensions to split across devices
-        accumulate_grad_batches = self._original_accumulate_grad_batches
-        return self.replication_factor * self.device_iterations * accumulate_grad_batches
+        opts = self.training_opts if self.lightning_module.training else self.inference_opts
+        accumulate_grad_batches = opts.Training.gradient_accumulation
+        device_iterations = opts.device_iterations
+        replication_factor = opts.replication_factor
+        return replication_factor * device_iterations * accumulate_grad_batches
 
     def _prepare_input(self, args: Any):
diff --git a/tests/accelerators/test_ipu.py b/tests/accelerators/test_ipu.py
index eb5a5349483bd..363648c9f681d 100644
--- a/tests/accelerators/test_ipu.py
+++ b/tests/accelerators/test_ipu.py
@@ -460,6 +460,45 @@ def test_manual_poptorch_opts_train_grad_accum(tmpdir):
     assert trainer.accelerator.training_type_plugin.inference_opts.Training.gradient_accumulation == 1
 
 
+@RunIf(ipu=True)
+def test_manual_poptorch_opts_custom(tmpdir):
+    """
+    Ensure if the user passes manual poptorch Options with custom parameters set,
+    we respect them in our poptorch options.
+    """
+
+    model = IPUModel()
+    inference_opts = poptorch.Options()
+    inference_opts.deviceIterations(16)
+    inference_opts.replicationFactor(2)
+    inference_opts.Training.gradientAccumulation(1)
+
+    training_opts = poptorch.Options()
+    training_opts.deviceIterations(8)
+    training_opts.replicationFactor(2)
+    training_opts.Training.gradientAccumulation(2)
+
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        ipus=2,
+        fast_dev_run=True,
+        accumulate_grad_batches=2,
+        plugins=IPUPlugin(inference_opts=inference_opts, training_opts=training_opts)
+    )
+    trainer.fit(model)
+    plugin = trainer.accelerator.training_type_plugin
+    assert isinstance(plugin, IPUPlugin)
+    inference_opts = plugin.inference_opts
+    training_opts = plugin.training_opts
+    assert inference_opts.device_iterations == 16
+    assert inference_opts.replication_factor == 2
+    assert inference_opts.Training.gradient_accumulation == 1
+
+    assert training_opts.device_iterations == 8
+    assert training_opts.replication_factor == 2
+    assert training_opts.Training.gradient_accumulation == 2
+
+
 @RunIf(ipu=True)
 def test_default_opts(tmpdir):
     """