47 changes: 29 additions & 18 deletions pytorch_lightning/plugins/training_type/ipu.py
@@ -108,7 +108,6 @@ def __init__(
os.environ["POPLAR_ENGINE_OPTIONS"] = json.dumps(options)

def pre_dispatch(self) -> None:
- self._handle_gradient_accumulation_steps()
precision = self.lightning_module.trainer.precision
model = LightningIPUModule(self.lightning_module, precision)
self.model = model
@@ -127,6 +126,7 @@ def pre_dispatch(self) -> None:
options=self.inference_opts,
)
self.poptorch_models[x] = model
+ self._handle_gradient_accumulation_steps()

@property
def replication_factor(self):
@@ -136,7 +136,7 @@ def _create_opts(self, training: bool):
opts = poptorch.Options()
opts.deviceIterations(self.device_iterations)
opts.replicationFactor(self.replication_factor)
- gradient_accumulation = self.lightning_module.trainer.accumulate_grad_batches if training else 1
+ gradient_accumulation = self.accumulate_grad_batches if training else 1
opts.Training.gradientAccumulation(gradient_accumulation)

if os.environ.get("PL_GLOBAL_SEED"):
@@ -167,7 +167,7 @@ def _validate_opts(self, opts: 'poptorch.Options', training: bool) -> None:
)
opts.set(replication_factor=self.replication_factor)
if training:
- accumulate_grad_batches = self.lightning_module.trainer.accumulate_grad_batches
+ accumulate_grad_batches = self.accumulate_grad_batches
if opts.Training.gradient_accumulation != accumulate_grad_batches:
rank_zero_warn(
f"Training poptorch.Options set gradientAccumulation to {opts.Training.gradient_accumulation}. "
@@ -211,9 +211,9 @@ def process_dataloader(self, dataloader: Union[Iterable, DataLoader]) -> Union[I
dataloader = apply_to_collection(dataloader, DataLoader, self.process_dataloader)
return dataloader
if not isinstance(dataloader, poptorch.DataLoader):
- dataloader = self._convert_to_poptorch_loader(
-     dataloader=dataloader, opts=self._create_opts(training=self.lightning_module.training)
- )
+ is_training = self.lightning_module.trainer.training
+ opts = self.training_opts if is_training else self.inference_opts
+ dataloader = self._convert_to_poptorch_loader(dataloader=dataloader, opts=opts)
return dataloader

def _convert_to_poptorch_loader(self, dataloader: Union[Iterable, DataLoader],
@@ -242,33 +242,44 @@ def _convert_to_poptorch_loader(self, dataloader: Union[Iterable, DataLoader],
dataloader.multiprocessing_context = multiprocessing_context
return dataloader

+ @property
+ def accumulate_grad_batches(self) -> int:
+     """
+     Lazily tracks the ``accumulate_grad_batches`` value set on the trainer.
+     The IPUPlugin replaces the trainer's original ``accumulate_grad_batches``.
+     """
+     if self._original_accumulate_grad_batches is None:
+         self._original_accumulate_grad_batches = self.lightning_module.trainer.accumulate_grad_batches
+         if not isinstance(self._original_accumulate_grad_batches, int):
+             raise MisconfigurationException(
+                 f"IPUs currently only support accumulate_grad_batches being an integer value. "
+                 f"Received {self.accumulate_grad_batches}"
+             )
+     return self._original_accumulate_grad_batches
+
def _handle_gradient_accumulation_steps(self):
"""
This function overrides the trainer.accumulation_scheduler so that
``accumulate_grad_batches=1``.
Therefore, ``optimizer_step`` will be called on every batch, and the IPU will handle grad accumulation.
"""
- self._original_accumulate_grad_batches = self.lightning_module.trainer.accumulate_grad_batches
- if not isinstance(self._original_accumulate_grad_batches, int):
-     raise MisconfigurationException(
-         f"IPUs currently only support accumulate_grad_batches being an integer value. "
-         f"Received {self._original_accumulate_grad_batches}"
-     )
- if self._original_accumulate_grad_batches > 1:
+ if self.accumulate_grad_batches > 1:
self.lightning_module.trainer.accumulation_scheduler = GradientAccumulationScheduler({0: 1})

def update_global_step(self, total_batch_idx: int, current_global_step: int) -> int:
- if self._original_accumulate_grad_batches > 1:
-     if total_batch_idx % self._original_accumulate_grad_batches == 0:
+ if self.accumulate_grad_batches > 1:
+     if total_batch_idx % self.accumulate_grad_batches == 0:
current_global_step += 1
return current_global_step
return super().update_global_step(total_batch_idx, current_global_step)

@property
def _n_replicate(self):
- # Ensure we replicate values to have enough dimensions to split across devices
- accumulate_grad_batches = self._original_accumulate_grad_batches
- return self.replication_factor * self.device_iterations * accumulate_grad_batches
+ opts = self.training_opts if self.lightning_module.training else self.inference_opts
+ accumulate_grad_batches = opts.Training.gradient_accumulation
+ device_iterations = opts.device_iterations
+ replication_factor = opts.replication_factor
+ return replication_factor * device_iterations * accumulate_grad_batches

def _prepare_input(self, args: Any):

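For readers unfamiliar with the poptorch options involved, the following is a minimal sketch, not part of the diff, of how the Trainer flags map onto the options that `_create_opts` builds and onto the `_n_replicate` computation above. The numeric values are illustrative only; the poptorch calls are the ones already used in the plugin.

import poptorch

# Illustrative values, assuming Trainer(accumulate_grad_batches=4) and an IPUPlugin
# configured with device_iterations=8 running on 2 replicas.
accumulate_grad_batches = 4
device_iterations = 8
replication_factor = 2

opts = poptorch.Options()
opts.deviceIterations(device_iterations)
opts.replicationFactor(replication_factor)
# Gradient accumulation is delegated to the IPU; Lightning's accumulation scheduler
# is forced back to {0: 1} by _handle_gradient_accumulation_steps().
opts.Training.gradientAccumulation(accumulate_grad_batches)

# _n_replicate: how many times each value must be replicated so one host batch can be
# split across device iterations, replicas and on-device accumulation steps.
n_replicate = replication_factor * device_iterations * accumulate_grad_batches  # 2 * 8 * 4 = 64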
39 changes: 39 additions & 0 deletions tests/accelerators/test_ipu.py
@@ -460,6 +460,45 @@ def test_manual_poptorch_opts_train_grad_accum(tmpdir):
assert trainer.accelerator.training_type_plugin.inference_opts.Training.gradient_accumulation == 1


@RunIf(ipu=True)
def test_manual_poptorch_opts_custom(tmpdir):
"""
Ensure if the user passes manual poptorch Options with custom parameters set,
we respect them in our poptorch options.
"""

model = IPUModel()
inference_opts = poptorch.Options()
inference_opts.deviceIterations(16)
inference_opts.replicationFactor(2)
inference_opts.Training.gradientAccumulation(1)

training_opts = poptorch.Options()
training_opts.deviceIterations(8)
training_opts.replicationFactor(2)
training_opts.Training.gradientAccumulation(2)

trainer = Trainer(
default_root_dir=tmpdir,
ipus=2,
fast_dev_run=True,
accumulate_grad_batches=2,
plugins=IPUPlugin(inference_opts=inference_opts, training_opts=training_opts)
)
trainer.fit(model)
plugin = trainer.accelerator.training_type_plugin
assert isinstance(plugin, IPUPlugin)
inference_opts = plugin.inference_opts
training_opts = plugin.training_opts
assert inference_opts.device_iterations == 16
assert inference_opts.replication_factor == 2
assert inference_opts.Training.gradient_accumulation == 1

assert training_opts.device_iterations == 8
assert training_opts.replication_factor == 2
assert training_opts.Training.gradient_accumulation == 2


@RunIf(ipu=True)
def test_default_opts(tmpdir):
"""
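The custom-options test above mirrors the intended user-facing pattern. Below is a minimal usage sketch under the same assumptions (an IPU machine with poptorch installed; the values are illustrative, the model is omitted, and import paths are as of the Lightning version this PR targets):

import poptorch
from pytorch_lightning import Trainer
from pytorch_lightning.plugins import IPUPlugin

# The plugin keeps the user-provided options (as asserted by the test above); if
# gradientAccumulation disagrees with the Trainer's accumulate_grad_batches,
# _validate_opts emits a rank-zero warning.
training_opts = poptorch.Options()
training_opts.deviceIterations(8)
training_opts.replicationFactor(2)
training_opts.Training.gradientAccumulation(2)

inference_opts = poptorch.Options()
inference_opts.deviceIterations(16)
inference_opts.replicationFactor(2)

trainer = Trainer(
    ipus=2,
    accumulate_grad_batches=2,
    plugins=IPUPlugin(training_opts=training_opts, inference_opts=inference_opts),
)
# trainer.fit(model) would then run training with the user-supplied poptorch options.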