@@ -89,10 +89,11 @@ class MockingLossModule(nn.Module):
8989
9090def mocking_trainer (file = None , optimizer = _mocking_optim ) -> Trainer :
9191 trainer = Trainer (
92- collector = MockingCollector (),
93- total_frames = None ,
94- frame_skip = None ,
95- optim_steps_per_batch = None ,
92+ MockingCollector (),
93+ * [
94+ None ,
95+ ]
96+ * 2 ,
9697 loss_module = MockingLossModule (),
9798 optimizer = optimizer ,
9899 save_trainer_file = file ,
@@ -861,7 +862,7 @@ def test_recorder(self, N=8):
861862 with tempfile .TemporaryDirectory () as folder :
862863 logger = TensorboardLogger (exp_name = folder )
863864
864- environment = transformed_env_constructor (
865+ recorder = transformed_env_constructor (
865866 args ,
866867 video_tag = "tmp" ,
867868 norm_obs_only = True ,
@@ -873,7 +874,7 @@ def test_recorder(self, N=8):
873874 record_frames = args .record_frames ,
874875 frame_skip = args .frame_skip ,
875876 policy_exploration = None ,
876- environment = environment ,
877+ recorder = recorder ,
877878 record_interval = args .record_interval ,
878879 )
879880 trainer = mocking_trainer ()
@@ -935,7 +936,7 @@ def _make_recorder_and_trainer(tmpdirname):
935936 raise NotImplementedError
936937 trainer = mocking_trainer (file )
937938
938- environment = transformed_env_constructor (
939+ recorder = transformed_env_constructor (
939940 args ,
940941 video_tag = "tmp" ,
941942 norm_obs_only = True ,
@@ -947,7 +948,7 @@ def _make_recorder_and_trainer(tmpdirname):
947948 record_frames = args .record_frames ,
948949 frame_skip = args .frame_skip ,
949950 policy_exploration = None ,
950- environment = environment ,
951+ recorder = recorder ,
951952 record_interval = args .record_interval ,
952953 )
953954 recorder .register (trainer )
0 commit comments