1414"""
1515Runs either `.fit()` or `.test()` on a single node across multiple gpus.
1616"""
17+ import os
1718from argparse import ArgumentParser
1819
20+ import tests as pl_tests
1921from pytorch_lightning import Trainer , seed_everything
2022from tests .base import EvalModelTemplate
21- import os
23+
2224import torch
2325
2426
2527def main ():
2628 seed_everything (1234 )
29+
2730 parser = ArgumentParser (add_help = False )
2831 parser = Trainer .add_argparse_args (parser )
2932 parser .add_argument ('--trainer_method' , default = 'fit' )
3033 parser .add_argument ('--tmpdir' )
34+ parser .add_argument ('--workdir' )
3135 parser .set_defaults (gpus = 2 )
3236 parser .set_defaults (distributed_backend = "ddp" )
3337 args = parser .parse_args ()
@@ -38,14 +42,26 @@ def main():
3842 result = {}
3943 if args .trainer_method == 'fit' :
4044 trainer .fit (model )
41- result = {'status' : 'complete' , 'method' : args .trainer_method , 'result' : None }
45+ result = {
46+ 'status' : 'complete' ,
47+ 'method' : args .trainer_method ,
48+ 'result' : None
49+ }
4250 if args .trainer_method == 'test' :
4351 result = trainer .test (model )
44- result = {'status' : 'complete' , 'method' : args .trainer_method , 'result' : result }
52+ result = {
53+ 'status' : 'complete' ,
54+ 'method' : args .trainer_method ,
55+ 'result' : result
56+ }
4557 if args .trainer_method == 'fit_test' :
4658 trainer .fit (model )
4759 result = trainer .test (model )
48- result = {'status' : 'complete' , 'method' : args .trainer_method , 'result' : result }
60+ result = {
61+ 'status' : 'complete' ,
62+ 'method' : args .trainer_method ,
63+ 'result' : result
64+ }
4965
5066 if len (result ) > 0 :
5167 file_path = os .path .join (args .tmpdir , 'ddp.result' )
0 commit comments