Skip to content

Commit 16fa4ed

Browse files
gianscarpeBordatchatonSeanNaren
authored
Fixed PYTHONPATH for ddp test model (#4528)
* Fixed PYTHONPATH for ddp test model * Removed debug calls * Apply suggestions from code review Co-authored-by: Jirka Borovec <[email protected]> Co-authored-by: chaton <[email protected]> Co-authored-by: Sean Naren <[email protected]>
1 parent 4bb3a08 commit 16fa4ed

File tree

2 files changed

+21
-7
lines changed

2 files changed

+21
-7
lines changed

tests/backends/ddp_model.py

Lines changed: 20 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,20 +14,24 @@
1414
"""
1515
Runs either `.fit()` or `.test()` on a single node across multiple gpus.
1616
"""
17+
import os
1718
from argparse import ArgumentParser
1819

20+
import tests as pl_tests
1921
from pytorch_lightning import Trainer, seed_everything
2022
from tests.base import EvalModelTemplate
21-
import os
23+
2224
import torch
2325

2426

2527
def main():
2628
seed_everything(1234)
29+
2730
parser = ArgumentParser(add_help=False)
2831
parser = Trainer.add_argparse_args(parser)
2932
parser.add_argument('--trainer_method', default='fit')
3033
parser.add_argument('--tmpdir')
34+
parser.add_argument('--workdir')
3135
parser.set_defaults(gpus=2)
3236
parser.set_defaults(distributed_backend="ddp")
3337
args = parser.parse_args()
@@ -38,14 +42,26 @@ def main():
3842
result = {}
3943
if args.trainer_method == 'fit':
4044
trainer.fit(model)
41-
result = {'status': 'complete', 'method': args.trainer_method, 'result': None}
45+
result = {
46+
'status': 'complete',
47+
'method': args.trainer_method,
48+
'result': None
49+
}
4250
if args.trainer_method == 'test':
4351
result = trainer.test(model)
44-
result = {'status': 'complete', 'method': args.trainer_method, 'result': result}
52+
result = {
53+
'status': 'complete',
54+
'method': args.trainer_method,
55+
'result': result
56+
}
4557
if args.trainer_method == 'fit_test':
4658
trainer.fit(model)
4759
result = trainer.test(model)
48-
result = {'status': 'complete', 'method': args.trainer_method, 'result': result}
60+
result = {
61+
'status': 'complete',
62+
'method': args.trainer_method,
63+
'result': result
64+
}
4965

5066
if len(result) > 0:
5167
file_path = os.path.join(args.tmpdir, 'ddp.result')

tests/utilities/distributed.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -29,11 +29,10 @@ def call_training_script(module_file, cli_args, method, tmpdir, timeout=60):
2929

3030
# need to set the PYTHONPATH in case pytorch_lightning was not installed into the environment
3131
env = os.environ.copy()
32-
env['PYTHONPATH'] = f'{pytorch_lightning.__file__}:' + env.get('PYTHONPATH', '')
32+
env['PYTHONPATH'] = env.get('PYTHONPATH', '') + f'{pytorch_lightning.__file__}:'
3333

3434
# for running in ddp mode, we need to lauch it's own process or pytest will get stuck
3535
p = subprocess.Popen(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, env=env)
36-
3736
try:
3837
std, err = p.communicate(timeout=timeout)
3938
err = str(err.decode("utf-8"))
@@ -42,5 +41,4 @@ def call_training_script(module_file, cli_args, method, tmpdir, timeout=60):
4241
except TimeoutExpired:
4342
p.kill()
4443
std, err = p.communicate()
45-
4644
return std, err

0 commit comments

Comments
 (0)