24 changes: 20 additions & 4 deletions tests/backends/ddp_model.py
@@ -14,20 +14,24 @@
"""
Runs either `.fit()` or `.test()` on a single node across multiple gpus.
"""
import os
from argparse import ArgumentParser

import tests as pl_tests
from pytorch_lightning import Trainer, seed_everything
from tests.base import EvalModelTemplate
import os

import torch


def main():
    seed_everything(1234)

    parser = ArgumentParser(add_help=False)
    parser = Trainer.add_argparse_args(parser)
    parser.add_argument('--trainer_method', default='fit')
    parser.add_argument('--tmpdir')
    parser.add_argument('--workdir')
    parser.set_defaults(gpus=2)
    parser.set_defaults(distributed_backend="ddp")
    args = parser.parse_args()
@@ -38,14 +42,26 @@ def main():
    result = {}
    if args.trainer_method == 'fit':
        trainer.fit(model)
        result = {'status': 'complete', 'method': args.trainer_method, 'result': None}
        result = {
            'status': 'complete',
            'method': args.trainer_method,
            'result': None
        }
    if args.trainer_method == 'test':
        result = trainer.test(model)
        result = {'status': 'complete', 'method': args.trainer_method, 'result': result}
        result = {
            'status': 'complete',
            'method': args.trainer_method,
            'result': result
        }
    if args.trainer_method == 'fit_test':
        trainer.fit(model)
        result = trainer.test(model)
        result = {'status': 'complete', 'method': args.trainer_method, 'result': result}
        result = {
            'status': 'complete',
            'method': args.trainer_method,
            'result': result
        }

    if len(result) > 0:
        file_path = os.path.join(args.tmpdir, 'ddp.result')
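The hunk above ends at the line that builds the 'ddp.result' path; the actual write-out is collapsed in this view. A minimal sketch of what the persistence round-trip could look like, assuming torch.save/torch.load (suggested by the `import torch` above, but not confirmed by the visible diff):

# Hypothetical sketch, not part of the diff: torch.save/torch.load and the helper
# names below are assumptions; the real save code is collapsed in this view.
import os
import torch

def save_result(result, tmpdir):
    # child process side: write the status/result dict where the parent test expects it
    torch.save(result, os.path.join(tmpdir, 'ddp.result'))

def load_result(tmpdir):
    # parent (pytest) side: read the dict back after the subprocess has exited
    return torch.load(os.path.join(tmpdir, 'ddp.result'))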
4 changes: 1 addition & 3 deletions tests/utilities/distributed.py
@@ -29,11 +29,10 @@ def call_training_script(module_file, cli_args, method, tmpdir, timeout=60):

    # need to set the PYTHONPATH in case pytorch_lightning was not installed into the environment
    env = os.environ.copy()
    env['PYTHONPATH'] = f'{pytorch_lightning.__file__}:' + env.get('PYTHONPATH', '')
    env['PYTHONPATH'] = env.get('PYTHONPATH', '') + f'{pytorch_lightning.__file__}:'

    # for running in ddp mode, we need to launch it in its own process or pytest will get stuck
    p = subprocess.Popen(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, env=env)

    try:
        std, err = p.communicate(timeout=timeout)
        err = str(err.decode("utf-8"))
@@ -42,5 +41,4 @@ def call_training_script(module_file, cli_args, method, tmpdir, timeout=60):
    except TimeoutExpired:
        p.kill()
        std, err = p.communicate()

    return std, err
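Taken together, the two files form a small harness: `call_training_script` launches `ddp_model.py` in its own process (with PYTHONPATH patched as above), and the script reports back through the file in `--tmpdir`. A hedged usage sketch follows; the test name, the cli_args format, the pytest `tmpdir` fixture, and the assertion are assumptions for illustration, not part of this diff:

# Hypothetical usage sketch: everything below is illustrative, not taken from the PR.
from tests.utilities.distributed import call_training_script
import tests.backends.ddp_model as ddp_model

def test_multi_gpu_model_ddp_fit(tmpdir):
    # assumption: the helper accepts the imported module (param name `module_file`)
    # and a flat string of CLI flags that it splits into argv
    cli_args = f'--max_epochs 1 --gpus 2 --distributed_backend ddp --tmpdir {tmpdir}'
    # spawns `python .../ddp_model.py ... --trainer_method fit` and waits for it to finish
    std, err = call_training_script(ddp_model, cli_args, 'fit', tmpdir, timeout=120)
    # illustrative check only; the real test's assertions are not shown in this diff
    assert 'Exception' not in str(err)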