@@ -772,3 +772,39 @@ def test_trainer_config(trainer_kwargs, expected):
772772 assert trainer .on_gpu is expected ["on_gpu" ]
773773 assert trainer .single_gpu is expected ["single_gpu" ]
774774 assert trainer .num_processes == expected ["num_processes" ]
775+
776+
777+ def test_trainer_subclassing ():
778+
779+ model = EvalModelTemplate ()
780+
781+ # First way of pulling out args from signature is to list them
782+ class TrainerSubclass (Trainer ):
783+
784+ def __init__ (self , custom_arg , * args , custom_kwarg = 'test' , ** kwargs ):
785+ super ().__init__ (* args , ** kwargs )
786+ self .custom_arg = custom_arg
787+ self .custom_kwarg = custom_kwarg
788+
789+ trainer = TrainerSubclass (123 , custom_kwarg = 'custom' , fast_dev_run = True )
790+ result = trainer .fit (model )
791+ assert result == 1
792+ assert trainer .custom_arg == 123
793+ assert trainer .custom_kwarg == 'custom'
794+ assert trainer .fast_dev_run
795+
796+ # Second way is to pop from the dict
797+ # It's a special case because Trainer does not have any positional args
798+ class TrainerSubclass (Trainer ):
799+
800+ def __init__ (self , ** kwargs ):
801+ self .custom_arg = kwargs .pop ('custom_arg' , 0 )
802+ self .custom_kwarg = kwargs .pop ('custom_kwarg' , 'test' )
803+ super ().__init__ (** kwargs )
804+
805+ trainer = TrainerSubclass (custom_kwarg = 'custom' , fast_dev_run = True )
806+ result = trainer .fit (model )
807+ assert result == 1
808+ assert trainer .custom_kwarg == 'custom'
809+ assert trainer .fast_dev_run
810+
0 commit comments