
Commit 6d55896

Sean Naren authored, with co-authors carmocca, awaelchli, and pre-commit-ci[bot]
[IPU] Allow poptorch.Options to override Trainer (#8233)
* Add test for poptorch Options
* Hacks to get manual plugin support
* Revert changes
* Fix tests + ensure logic follow suit
* Update pytorch_lightning/plugins/training_type/ipu.py

Co-authored-by: Adrian Wälchli <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Cleaner
* Cleaner

Co-authored-by: Carlos Mocholí <[email protected]>
Co-authored-by: Adrian Wälchli <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 5cef977 commit 6d55896

5 files changed: +95 / -161 lines
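
At a glance, the change lets a user construct their own `poptorch.Options` and hand them to `IPUPlugin` directly; the Trainer then infers the IPU count from those options instead of requiring the `ipus=` flag. A minimal usage sketch (assumes an IPU environment with poptorch installed; `MyModel` is a hypothetical LightningModule):

import poptorch
from pytorch_lightning import Trainer
from pytorch_lightning.plugins import IPUPlugin

# User-controlled options: 8 device iterations and 2 replicas for training,
# 16 device iterations and a single replica for inference.
training_opts = poptorch.Options()
training_opts.deviceIterations(8)
training_opts.replicationFactor(2)

inference_opts = poptorch.Options()
inference_opts.deviceIterations(16)
inference_opts.replicationFactor(1)

# No `ipus=...` flag: the plugin reports replication_factor == 2 from the
# training options, and the Trainer picks that up as the device count.
trainer = Trainer(plugins=IPUPlugin(training_opts=training_opts, inference_opts=inference_opts))
trainer.fit(MyModel())  # MyModel is a placeholder LightningModule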

pytorch_lightning/plugins/training_type/ipu.py

Lines changed: 22 additions & 41 deletions
@@ -26,7 +26,7 @@
 from pytorch_lightning.plugins.training_type.parallel import ParallelPlugin
 from pytorch_lightning.trainer.states import RunningStage
 from pytorch_lightning.trainer.supporters import CombinedLoader
-from pytorch_lightning.utilities import _POPTORCH_AVAILABLE, rank_zero_warn
+from pytorch_lightning.utilities import _POPTORCH_AVAILABLE
 from pytorch_lightning.utilities.apply_func import apply_to_collection
 from pytorch_lightning.utilities.cloud_io import get_filesystem
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
@@ -129,10 +129,18 @@ def pre_dispatch(self) -> None:
         self._handle_gradient_accumulation_steps()

     @property
-    def replication_factor(self):
+    def replication_factor(self) -> int:
+        if not self.lightning_module:
+            # The plugin has been passed in by the user and has not been connected to the Trainer.
+            # Check if the user has passed in custom poptorch.Options to infer number of IPUs being used.
+            # In this scenario we prioritize the training options.
+            if self._training_opts:
+                return self._training_opts.replication_factor
+            if self._inference_opts:
+                return self._inference_opts.replication_factor
         return len(self.parallel_devices)

-    def _create_opts(self, training: bool):
+    def _create_opts(self, training: bool) -> 'poptorch.Options':
         opts = poptorch.Options()
         opts.deviceIterations(self.device_iterations)
         opts.replicationFactor(self.replication_factor)
@@ -147,71 +155,44 @@ def _create_opts(self, training: bool):
     def training_opts(self) -> 'poptorch.Options':
         if self._training_opts is None:
             self._training_opts = self._create_opts(training=True)
-        self._validate_opts(self._training_opts, training=True)
         return self._training_opts

     @property
     def inference_opts(self) -> 'poptorch.Options':
         if self._inference_opts is None:
             self._inference_opts = self._create_opts(training=False)
-        self._validate_opts(self._inference_opts, training=False)
         return self._inference_opts

-    def _validate_opts(self, opts: 'poptorch.Options', training: bool) -> None:
-        if opts is not None:
-            if opts.replication_factor != self.replication_factor:
-                rank_zero_warn(
-                    f"Manual poptorch.Options set replicationFactor to {opts.replication_factor} "
-                    f"which differs to the ipus={self.replication_factor} flag passed to the Trainer. "
-                    f"Setting to {self.replication_factor} in the poptorch.Options."
-                )
-                opts.set(replication_factor=self.replication_factor)
-            if training:
-                accumulate_grad_batches = self.accumulate_grad_batches
-                if opts.Training.gradient_accumulation != accumulate_grad_batches:
-                    rank_zero_warn(
-                        f"Training poptorch.Options set gradientAccumulation to {opts.Training.gradient_accumulation}. "
-                        f"This is different to accumulate_grad_batches which was set to {accumulate_grad_batches}. "
-                        f"To change gradientAccumulation, please set accumulate_grad_batches in the Trainer. "
-                        f"Setting poptorch.Options gradientAccumulation to {accumulate_grad_batches}"
-                    )
-                    opts.Training.set(gradient_accumulation=accumulate_grad_batches)
-            elif opts.Training.gradient_accumulation != 1:
-                rank_zero_warn(
-                    "Inference poptorch.Options should set gradientAccumulation to 1. "
-                    "Setting gradientAccumulation to 1 for inference options."
-                )
-                opts.Training.set(gradient_accumulation=1)
-
     @property
     def lightning_module(self) -> Optional['pl.LightningModule']:
         return self.model.module if isinstance(self.model, LightningIPUModule) else self.model

     def on_reset_train_dataloader(self, dataloader: Union[Iterable, DataLoader]) -> Union[Iterable, DataLoader]:
-        return self.process_dataloader(dataloader)
+        return self._process_dataloader(dataloader, is_training=True)

     def on_reset_val_dataloader(self, dataloader: Union[Iterable, DataLoader]) -> Union[Iterable, DataLoader]:
-        return self.process_dataloader(dataloader)
+        return self._process_dataloader(dataloader, is_training=False)

     def on_reset_test_dataloader(self, dataloader: Union[Iterable, DataLoader]) -> Union[Iterable, DataLoader]:
-        return self.process_dataloader(dataloader)
+        return self._process_dataloader(dataloader, is_training=False)

     def on_reset_predict_dataloader(self, dataloader: Union[Iterable, DataLoader]) -> Union[Iterable, DataLoader]:
-        return self.process_dataloader(dataloader)
+        return self._process_dataloader(dataloader, is_training=False)

-    def process_dataloader(self, dataloader: Union[Iterable, DataLoader]) -> Union[Iterable, DataLoader]:
+    def _process_dataloader(
+        self,
+        dataloader: Union[Iterable, DataLoader],
+        is_training: bool,
+    ) -> Union[Iterable, DataLoader]:
         if isinstance(dataloader, CombinedLoader):
             dataloader.loaders = apply_to_collection(
-                dataloader.loaders,
-                DataLoader,
-                self.process_dataloader,
+                dataloader.loaders, DataLoader, self._process_dataloader, is_training
             )
             return dataloader
         if isinstance(dataloader, list):
-            dataloader = apply_to_collection(dataloader, DataLoader, self.process_dataloader)
+            dataloader = apply_to_collection(dataloader, DataLoader, self._process_dataloader, is_training)
             return dataloader
         if not isinstance(dataloader, poptorch.DataLoader):
-            is_training = self.lightning_module.trainer.training
             opts = self.training_opts if is_training else self.inference_opts
             dataloader = self._convert_to_poptorch_loader(dataloader=dataloader, opts=opts)
         return dataloader
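
The key pieces above are the `replication_factor` fallback and the explicit `is_training` flag threaded through dataloader processing. Before the plugin is attached to a Trainer, user-supplied options decide the IPU count, with the training options taking priority. A condensed restatement of that precedence as a standalone helper (hypothetical name, for illustration only, not part of the plugin API):

from typing import Optional

def resolve_replication_factor(
    connected_to_trainer: bool,
    training_opts: Optional["poptorch.Options"],
    inference_opts: Optional["poptorch.Options"],
    parallel_devices: list,
) -> int:
    # Plugin not yet connected: custom poptorch.Options decide the IPU count,
    # and the training options win over the inference options.
    if not connected_to_trainer:
        if training_opts:
            return training_opts.replication_factor
        if inference_opts:
            return inference_opts.replication_factor
    # Otherwise fall back to the device list the Trainer built from `ipus=...`.
    return len(parallel_devices)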

pytorch_lightning/trainer/connectors/accelerator_connector.py

Lines changed: 10 additions & 3 deletions
@@ -259,7 +259,7 @@ def on_tpu(self) -> bool:

     @property
     def on_ipu(self) -> bool:
-        return self.ipus is not None
+        return self.ipus is not None or isinstance(self._training_type_plugin, IPUPlugin)

     @property
     def tpu_id(self) -> Optional[int]:
@@ -327,6 +327,14 @@ def num_gpus(self) -> int:
             return 0
         return len(gpus)

+    @property
+    def num_ipus(self) -> int:
+        if isinstance(self.ipus, int):
+            return self.ipus
+        if isinstance(self._training_type_plugin, IPUPlugin):
+            return self._training_type_plugin.replication_factor
+        return 0
+
     @property
     def parallel_devices(self) -> List[Union[torch.device, int]]:
         if self.on_gpu:
@@ -337,8 +345,7 @@ def parallel_devices(self) -> List[Union[torch.device, int]]:
             if isinstance(self.tpu_cores, int):
                 devices = list(range(self.tpu_cores))
         elif self.on_ipu:
-            if isinstance(self.ipus, int):
-                devices = list(range(self.ipus))
+            devices = list(range(self.num_ipus))
         else:
             devices = [torch.device("cpu")] * self.num_processes
         return devices
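
The connector change means a run counts as an IPU run when either the `ipus=` flag is set or an `IPUPlugin` was passed in manually, and the new `num_ipus` property resolves the device count accordingly. A small sketch of that resolution order as a standalone function (illustrative only, assuming `IPUPlugin` is importable):

from pytorch_lightning.plugins import IPUPlugin

def resolve_num_ipus(ipus_flag, training_type_plugin) -> int:
    # An explicit `ipus=` flag takes precedence.
    if isinstance(ipus_flag, int):
        return ipus_flag
    # Otherwise a manually passed IPUPlugin supplies its replication factor.
    if isinstance(training_type_plugin, IPUPlugin):
        return training_type_plugin.replication_factor
    # Not an IPU run.
    return 0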

pytorch_lightning/trainer/properties.py

Lines changed: 1 addition & 1 deletion
@@ -137,7 +137,7 @@ def tpu_cores(self) -> int:

     @property
     def ipus(self) -> int:
-        return self.accelerator_connector.ipus
+        return self.accelerator_connector.num_ipus

     @property
     def num_gpus(self) -> int:

pytorch_lightning/trainer/trainer.py

Lines changed: 2 additions & 2 deletions
@@ -23,7 +23,7 @@
 import torch

 import pytorch_lightning as pl
-from pytorch_lightning.accelerators import Accelerator
+from pytorch_lightning.accelerators import Accelerator, IPUAccelerator
 from pytorch_lightning.callbacks import Callback
 from pytorch_lightning.core.datamodule import LightningDataModule
 from pytorch_lightning.core.memory import ModelSummary
@@ -1209,7 +1209,7 @@ def _log_device_info(self) -> None:
                 " `Trainer(tpu_cores=8)` or script `--tpu_cores=8`."
             )

-        if _IPU_AVAILABLE and self._device_type != DeviceType.IPU:
+        if _IPU_AVAILABLE and self._device_type != DeviceType.IPU and not isinstance(self.accelerator, IPUAccelerator):
             rank_zero_warn(
                 "IPU available but not used. Set the `ipus` flag in your trainer"
                 " `Trainer(ipus=8)` or script `--ipus=8`."

tests/accelerators/test_ipu.py

Lines changed: 60 additions & 114 deletions
@@ -23,6 +23,7 @@
 from pytorch_lightning.core.lightning import LightningModule
 from pytorch_lightning.plugins import IPUPlugin, IPUPrecisionPlugin
 from pytorch_lightning.trainer.states import RunningStage
+from pytorch_lightning.trainer.supporters import CombinedLoader
 from pytorch_lightning.utilities import _IPU_AVAILABLE
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from tests.helpers.boring_model import BoringModel
@@ -112,6 +113,19 @@ def test_accelerator_selected(tmpdir):
     assert isinstance(trainer.accelerator, IPUAccelerator)


+@RunIf(ipu=True)
+def test_warning_if_ipus_not_used(tmpdir):
+    with pytest.warns(UserWarning, match="IPU available but not used. Set the `ipus` flag in your trainer"):
+        Trainer(default_root_dir=tmpdir)
+
+
+@RunIf(ipu=True)
+def test_no_warning_plugin(tmpdir):
+    with pytest.warns(None) as record:
+        Trainer(default_root_dir=tmpdir, plugins=IPUPlugin(training_opts=poptorch.Options()))
+    assert len(record) == 0
+
+
 @RunIf(ipu=True)
 @pytest.mark.parametrize('ipus', [1, 4])
 def test_all_stages(tmpdir, ipus):
@@ -364,140 +378,72 @@ def test_manual_poptorch_opts(tmpdir):


 @RunIf(ipu=True)
-def test_manual_poptorch_opts_ipu_count(tmpdir):
-    """
-    Ensure if the user passes manual poptorch Options
-    and the number of ipus do not match, we warn and we set it for the user.
-    """
-
-    manual_ipus = 1
-    expected_ipus = 2
-    model = IPUModel()
-    inference_opts = poptorch.Options()
-    inference_opts.replicationFactor(manual_ipus)
-
-    training_opts = poptorch.Options()
-    training_opts.replicationFactor(manual_ipus)
-
-    trainer = Trainer(
-        default_root_dir=tmpdir,
-        ipus=expected_ipus,
-        fast_dev_run=True,
-        plugins=IPUPlugin(inference_opts=inference_opts, training_opts=training_opts)
-    )
-    with pytest.warns(
-        UserWarning,
-        match=f"Manual poptorch.Options set replicationFactor to {manual_ipus} "
-        f"which differs to the ipus={expected_ipus} flag passed to the Trainer. "
-        f"Setting to {expected_ipus} in the poptorch.Options."
-    ):
-        trainer.fit(model)
-    assert isinstance(trainer.accelerator.training_type_plugin, IPUPlugin)
-    assert trainer.accelerator.training_type_plugin.training_opts.replication_factor == 2
-    assert trainer.accelerator.training_type_plugin.inference_opts.replication_factor == 2
-
-
-@RunIf(ipu=True)
-def test_manual_poptorch_opts_inference_grad_accum(tmpdir):
-    """
-    Ensure if the user passes manual poptorch Options
-    and grad accumulation is set greater than 1 for inference, we warn and set to 1.
-    """
-
-    model = IPUModel()
-    inference_opts = poptorch.Options()
-    inference_opts.Training.gradientAccumulation(4)
-
-    training_opts = poptorch.Options()
-    training_opts.Training.gradientAccumulation(1)
-
-    trainer = Trainer(
-        default_root_dir=tmpdir,
-        ipus=1,
-        fast_dev_run=True,
-        plugins=IPUPlugin(inference_opts=inference_opts, training_opts=training_opts)
-    )
-    with pytest.warns(
-        UserWarning,
-        match="Inference poptorch.Options should set gradientAccumulation to 1. "
-        "Setting gradientAccumulation to 1 for inference options.",
-    ):
-        trainer.fit(model)
-    assert isinstance(trainer.accelerator.training_type_plugin, IPUPlugin)
-    assert trainer.accelerator.training_type_plugin.inference_opts.Training.gradient_accumulation == 1
-
-
-@RunIf(ipu=True)
-def test_manual_poptorch_opts_train_grad_accum(tmpdir):
+def test_manual_poptorch_opts_custom(tmpdir):
     """
-    Ensure if the user passes manual poptorch Options
-    and grad accumulation differs to accumulate_grad_batches, we
+    Ensure if the user passes manual poptorch Options with custom parameters set,
+    we respect them in our poptorch options and the dataloaders.
     """

     model = IPUModel()
-    inference_opts = poptorch.Options()
-    inference_opts.Training.gradientAccumulation(1)
-
     training_opts = poptorch.Options()
+    training_opts.deviceIterations(8)
+    training_opts.replicationFactor(2)
     training_opts.Training.gradientAccumulation(2)

-    trainer = Trainer(
-        default_root_dir=tmpdir,
-        ipus=1,
-        fast_dev_run=True,
-        accumulate_grad_batches=1,
-        plugins=IPUPlugin(inference_opts=inference_opts, training_opts=training_opts)
-    )
-    with pytest.warns(
-        UserWarning,
-        match=f"Training poptorch.Options set gradientAccumulation to {2}. "
-        f"This is different to accumulate_grad_batches which was set to {1}. "
-        f"To change gradientAccumulation, please set accumulate_grad_batches in the Trainer. "
-        f"Setting poptorch.Options gradientAccumulation to {1}",
-    ):
-        trainer.fit(model)
-    assert isinstance(trainer.accelerator.training_type_plugin, IPUPlugin)
-    assert trainer.accelerator.training_type_plugin.inference_opts.Training.gradient_accumulation == 1
-
-
-@RunIf(ipu=True)
-def test_manual_poptorch_opts_custom(tmpdir):
-    """
-    Ensure if the user passes manual poptorch Options with custom parameters set,
-    we respect them in our poptorch options.
-    """
-
-    model = IPUModel()
     inference_opts = poptorch.Options()
     inference_opts.deviceIterations(16)
-    inference_opts.replicationFactor(2)
+    inference_opts.replicationFactor(1)
     inference_opts.Training.gradientAccumulation(1)

-    training_opts = poptorch.Options()
-    training_opts.deviceIterations(8)
-    training_opts.replicationFactor(2)
-    training_opts.Training.gradientAccumulation(2)
+    class TestCallback(Callback):

-    trainer = Trainer(
-        default_root_dir=tmpdir,
-        ipus=2,
-        fast_dev_run=True,
-        accumulate_grad_batches=2,
-        plugins=IPUPlugin(inference_opts=inference_opts, training_opts=training_opts)
-    )
+        def on_fit_end(self, trainer: Trainer, pl_module: LightningModule) -> None:
+            # ensure dataloaders were correctly set up during training.
+            plugin = trainer.accelerator.training_type_plugin
+            assert isinstance(plugin, IPUPlugin)
+            assert plugin.training_opts.replication_factor == 2
+            assert plugin.inference_opts.replication_factor == 1
+
+            val_dataloader = trainer.val_dataloaders[0]
+            train_dataloader = trainer.train_dataloader
+            assert isinstance(train_dataloader, CombinedLoader)
+            train_dataloader = train_dataloader.loaders
+            assert isinstance(val_dataloader, poptorch.DataLoader)
+            assert isinstance(train_dataloader, poptorch.DataLoader)
+            assert train_dataloader.options.replication_factor == 2
+            assert val_dataloader.options.replication_factor == 1
+
+    plugin = IPUPlugin(inference_opts=inference_opts, training_opts=training_opts)
+    # ensure we default to the training options replication factor
+    assert plugin.replication_factor == 2
+    trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True, plugins=plugin, callbacks=TestCallback())
     trainer.fit(model)
+
     plugin = trainer.accelerator.training_type_plugin
     assert isinstance(plugin, IPUPlugin)
-    inference_opts = plugin.inference_opts
-    training_opts = plugin.training_opts
-    assert inference_opts.device_iterations == 16
-    assert inference_opts.replication_factor == 2
-    assert inference_opts.Training.gradient_accumulation == 1

+    training_opts = plugin.training_opts
     assert training_opts.device_iterations == 8
     assert training_opts.replication_factor == 2
     assert training_opts.Training.gradient_accumulation == 2

+    inference_opts = plugin.inference_opts
+    assert inference_opts.device_iterations == 16
+    assert inference_opts.replication_factor == 1
+    assert inference_opts.Training.gradient_accumulation == 1
+
+
+@RunIf(ipu=True)
+def test_replication_factor(tmpdir):
+    """
+    Ensure if the user passes manual poptorch Options with custom parameters set,
+    we set them correctly in the dataloaders.
+    """
+
+    plugin = IPUPlugin()
+    trainer = Trainer(ipus=2, default_root_dir=tmpdir, fast_dev_run=True, plugins=plugin)
+    assert trainer.ipus == 2
+

 @RunIf(ipu=True)
 def test_default_opts(tmpdir):