@@ -737,13 +737,15 @@ def _create_seq_mock_data_td3(
737737
738738 @pytest .mark .skipif (not _has_functorch , reason = "functorch not installed" )
739739 @pytest .mark .parametrize ("device" , get_available_devices ())
740- @pytest .mark .parametrize ("delay_actor, delay_value" , [(False , False ), (True , True )])
740+ @pytest .mark .parametrize (
741+ "delay_actor, delay_qvalue" , [(False , False ), (True , True )]
742+ )
741743 @pytest .mark .parametrize ("policy_noise" , [0.1 , 1.0 ])
742744 @pytest .mark .parametrize ("noise_clip" , [0.1 , 1.0 ])
743745 def test_td3 (
744746 self ,
745747 delay_actor ,
746- delay_value ,
748+ delay_qvalue ,
747749 device ,
748750 policy_noise ,
749751 noise_clip ,
@@ -760,11 +762,19 @@ def test_td3(
760762 policy_noise = policy_noise ,
761763 noise_clip = noise_clip ,
762764 delay_actor = delay_actor ,
763- delay_value = delay_value ,
765+ delay_qvalue = delay_qvalue ,
764766 )
765767 with _check_td_steady (td ):
766768 loss = loss_fn (td )
767769
770+ assert all (
771+ (p .grad is None ) or (p .grad == 0 ).all ()
772+ for p in loss_fn .qvalue_network_params .values (True , True )
773+ )
774+ assert all (
775+ (p .grad is None ) or (p .grad == 0 ).all ()
776+ for p in loss_fn .actor_network_params .values (True , True )
777+ )
768778 # check that losses are independent
769779 for k in loss .keys ():
770780 if not k .startswith ("loss" ):
@@ -773,71 +783,43 @@ def test_td3(
773783 if k == "loss_actor" :
774784 assert all (
775785 (p .grad is None ) or (p .grad == 0 ).all ()
776- for p in loss_fn .value_network_params
786+ for p in loss_fn .qvalue_network_params . values ( True , True )
777787 )
778788 assert not any (
779789 (p .grad is None ) or (p .grad == 0 ).all ()
780- for p in loss_fn .actor_network_params
790+ for p in loss_fn .actor_network_params . values ( True , True )
781791 )
782792 elif k == "loss_qvalue" :
783793 assert all (
784794 (p .grad is None ) or (p .grad == 0 ).all ()
785- for p in loss_fn .actor_network_params
795+ for p in loss_fn .actor_network_params . values ( True , True )
786796 )
787797 assert not any (
788798 (p .grad is None ) or (p .grad == 0 ).all ()
789- for p in loss_fn .value_network_params
799+ for p in loss_fn .qvalue_network_params . values ( True , True )
790800 )
791801 else :
792802 raise NotImplementedError (k )
793803 loss_fn .zero_grad ()
794804
795- # check overall grad
796805 sum ([item for _ , item in loss .items ()]).backward ()
797- parameters = list (actor .parameters ()) + list (value .parameters ())
798- for p in parameters :
799- assert p .grad .norm () > 0.0
806+ named_parameters = list (loss_fn .named_parameters ())
807+ named_buffers = list (loss_fn .named_buffers ())
800808
801- # Check param update effect on targets
802- target_actor = [p .clone () for p in loss_fn .target_actor_network_params ]
803- target_value = [p .clone () for p in loss_fn .target_value_network_params ]
804- for p in loss_fn .parameters ():
805- p .data += torch .randn_like (p )
806- target_actor2 = [p .clone () for p in loss_fn .target_actor_network_params ]
807- target_value2 = [p .clone () for p in loss_fn .target_value_network_params ]
808- if loss_fn .delay_actor :
809- assert all ((p1 == p2 ).all () for p1 , p2 in zip (target_actor , target_actor2 ))
810- else :
811- assert not any (
812- (p1 == p2 ).any () for p1 , p2 in zip (target_actor , target_actor2 )
813- )
814- if loss_fn .delay_value :
815- assert all ((p1 == p2 ).all () for p1 , p2 in zip (target_value , target_value2 ))
816- else :
817- assert not any (
818- (p1 == p2 ).any () for p1 , p2 in zip (target_value , target_value2 )
819- )
809+ assert len ({p for n , p in named_parameters }) == len (list (named_parameters ))
810+ assert len ({p for n , p in named_buffers }) == len (list (named_buffers ))
820811
821- # check that policy is updated after parameter update
822- parameters = [p .clone () for p in actor .parameters ()]
823- for p in loss_fn .parameters ():
824- p .data += torch .randn_like (p )
825- assert all ((p1 != p2 ).all () for p1 , p2 in zip (parameters , actor .parameters ()))
812+ for name , p in named_parameters :
813+ assert p .grad .norm () > 0.0 , f"parameter { name } has a null gradient"
826814
827815 @pytest .mark .skipif (not _has_functorch , reason = "functorch not installed" )
828816 @pytest .mark .parametrize ("n" , list (range (4 )))
829817 @pytest .mark .parametrize ("device" , get_available_devices ())
830- @pytest .mark .parametrize ("delay_actor,delay_value " , [(False , False ), (True , True )])
818+ @pytest .mark .parametrize ("delay_actor,delay_qvalue " , [(False , False ), (True , True )])
831819 @pytest .mark .parametrize ("policy_noise" , [0.1 , 1.0 ])
832820 @pytest .mark .parametrize ("noise_clip" , [0.1 , 1.0 ])
833821 def test_td3_batcher (
834- self ,
835- n ,
836- delay_actor ,
837- delay_value ,
838- device ,
839- policy_noise ,
840- noise_clip ,
822+ self , n , delay_actor , delay_qvalue , device , policy_noise , noise_clip , gamma = 0.9
841823 ):
842824 torch .manual_seed (self .seed )
843825 actor = self ._create_mock_actor (device = device )
@@ -847,18 +829,27 @@ def test_td3_batcher(
847829 actor ,
848830 value ,
849831 gamma = 0.9 ,
850- loss_function = "l2" ,
851832 policy_noise = policy_noise ,
852833 noise_clip = noise_clip ,
834+ delay_qvalue = delay_qvalue ,
853835 delay_actor = delay_actor ,
854- delay_value = delay_value ,
855836 )
856837
857- ms = MultiStep (gamma = 0.9 , n_steps_max = n ).to (device )
858- ms_td = ms (td .clone ())
838+ ms = MultiStep (gamma = gamma , n_steps_max = n ).to (device )
839+
840+ td_clone = td .clone ()
841+ ms_td = ms (td_clone )
842+
843+ torch .manual_seed (0 )
844+ np .random .seed (0 )
845+
859846 with _check_td_steady (ms_td ):
860847 loss_ms = loss_fn (ms_td )
848+ assert loss_fn .priority_key in ms_td .keys ()
849+
861850 with torch .no_grad ():
851+ torch .manual_seed (0 ) # log-prob is computed with a random action
852+ np .random .seed (0 )
862853 loss = loss_fn (td )
863854 if n == 0 :
864855 assert_allclose_td (td , ms_td .select (* list (td .keys ())))
@@ -870,10 +861,50 @@ def test_td3_batcher(
870861 else :
871862 with pytest .raises (AssertionError ):
872863 assert_allclose_td (loss , loss_ms )
864+
873865 sum ([item for _ , item in loss_ms .items ()]).backward ()
874- parameters = list (actor .parameters ()) + list (value .parameters ())
875- for p in parameters :
876- assert p .grad .norm () > 0.0
866+ named_parameters = loss_fn .named_parameters ()
867+ for name , p in named_parameters :
868+ assert p .grad .norm () > 0.0 , f"parameter { name } has null gradient"
869+
870+ # Check param update effect on targets
871+ target_actor = loss_fn .target_actor_network_params .clone ().values (
872+ include_nested = True , leaves_only = True
873+ )
874+ target_qvalue = loss_fn .target_qvalue_network_params .clone ().values (
875+ include_nested = True , leaves_only = True
876+ )
877+ for p in loss_fn .parameters ():
878+ p .data += torch .randn_like (p )
879+ target_actor2 = loss_fn .target_actor_network_params .clone ().values (
880+ include_nested = True , leaves_only = True
881+ )
882+ target_qvalue2 = loss_fn .target_qvalue_network_params .clone ().values (
883+ include_nested = True , leaves_only = True
884+ )
885+ if loss_fn .delay_actor :
886+ assert all ((p1 == p2 ).all () for p1 , p2 in zip (target_actor , target_actor2 ))
887+ else :
888+ assert not any (
889+ (p1 == p2 ).any () for p1 , p2 in zip (target_actor , target_actor2 )
890+ )
891+ if loss_fn .delay_qvalue :
892+ assert all (
893+ (p1 == p2 ).all () for p1 , p2 in zip (target_qvalue , target_qvalue2 )
894+ )
895+ else :
896+ assert not any (
897+ (p1 == p2 ).any () for p1 , p2 in zip (target_qvalue , target_qvalue2 )
898+ )
899+
900+ # check that policy is updated after parameter update
901+ actorp_set = set (actor .parameters ())
902+ loss_fnp_set = set (loss_fn .parameters ())
903+ assert len (actorp_set .intersection (loss_fnp_set )) == len (actorp_set )
904+ parameters = [p .clone () for p in actor .parameters ()]
905+ for p in loss_fn .parameters ():
906+ p .data += torch .randn_like (p )
907+ assert all ((p1 != p2 ).all () for p1 , p2 in zip (parameters , actor .parameters ()))
877908
878909
879910class TestSAC :
0 commit comments