
Commit 5d1b995

take off not needed actorcritic wrapper
1 parent f7c315d commit 5d1b995

File tree

2 files changed: +95, -58 lines

test/test_cost.py

Lines changed: 81 additions & 50 deletions
@@ -737,13 +737,15 @@ def _create_seq_mock_data_td3(
 
     @pytest.mark.skipif(not _has_functorch, reason="functorch not installed")
     @pytest.mark.parametrize("device", get_available_devices())
-    @pytest.mark.parametrize("delay_actor, delay_value", [(False, False), (True, True)])
+    @pytest.mark.parametrize(
+        "delay_actor, delay_qvalue", [(False, False), (True, True)]
+    )
     @pytest.mark.parametrize("policy_noise", [0.1, 1.0])
     @pytest.mark.parametrize("noise_clip", [0.1, 1.0])
     def test_td3(
         self,
         delay_actor,
-        delay_value,
+        delay_qvalue,
         device,
         policy_noise,
         noise_clip,
@@ -760,11 +762,19 @@ def test_td3(
             policy_noise=policy_noise,
             noise_clip=noise_clip,
             delay_actor=delay_actor,
-            delay_value=delay_value,
+            delay_qvalue=delay_qvalue,
         )
         with _check_td_steady(td):
             loss = loss_fn(td)
 
+        assert all(
+            (p.grad is None) or (p.grad == 0).all()
+            for p in loss_fn.qvalue_network_params.values(True, True)
+        )
+        assert all(
+            (p.grad is None) or (p.grad == 0).all()
+            for p in loss_fn.actor_network_params.values(True, True)
+        )
         # check that losses are independent
         for k in loss.keys():
             if not k.startswith("loss"):
@@ -773,71 +783,43 @@ def test_td3(
             if k == "loss_actor":
                 assert all(
                     (p.grad is None) or (p.grad == 0).all()
-                    for p in loss_fn.value_network_params
+                    for p in loss_fn.qvalue_network_params.values(True, True)
                 )
                 assert not any(
                     (p.grad is None) or (p.grad == 0).all()
-                    for p in loss_fn.actor_network_params
+                    for p in loss_fn.actor_network_params.values(True, True)
                 )
             elif k == "loss_qvalue":
                 assert all(
                     (p.grad is None) or (p.grad == 0).all()
-                    for p in loss_fn.actor_network_params
+                    for p in loss_fn.actor_network_params.values(True, True)
                 )
                 assert not any(
                     (p.grad is None) or (p.grad == 0).all()
-                    for p in loss_fn.value_network_params
+                    for p in loss_fn.qvalue_network_params.values(True, True)
                 )
             else:
                 raise NotImplementedError(k)
         loss_fn.zero_grad()
 
-        # check overall grad
         sum([item for _, item in loss.items()]).backward()
-        parameters = list(actor.parameters()) + list(value.parameters())
-        for p in parameters:
-            assert p.grad.norm() > 0.0
+        named_parameters = list(loss_fn.named_parameters())
+        named_buffers = list(loss_fn.named_buffers())
 
-        # Check param update effect on targets
-        target_actor = [p.clone() for p in loss_fn.target_actor_network_params]
-        target_value = [p.clone() for p in loss_fn.target_value_network_params]
-        for p in loss_fn.parameters():
-            p.data += torch.randn_like(p)
-        target_actor2 = [p.clone() for p in loss_fn.target_actor_network_params]
-        target_value2 = [p.clone() for p in loss_fn.target_value_network_params]
-        if loss_fn.delay_actor:
-            assert all((p1 == p2).all() for p1, p2 in zip(target_actor, target_actor2))
-        else:
-            assert not any(
-                (p1 == p2).any() for p1, p2 in zip(target_actor, target_actor2)
-            )
-        if loss_fn.delay_value:
-            assert all((p1 == p2).all() for p1, p2 in zip(target_value, target_value2))
-        else:
-            assert not any(
-                (p1 == p2).any() for p1, p2 in zip(target_value, target_value2)
-            )
+        assert len({p for n, p in named_parameters}) == len(list(named_parameters))
+        assert len({p for n, p in named_buffers}) == len(list(named_buffers))
 
-        # check that policy is updated after parameter update
-        parameters = [p.clone() for p in actor.parameters()]
-        for p in loss_fn.parameters():
-            p.data += torch.randn_like(p)
-        assert all((p1 != p2).all() for p1, p2 in zip(parameters, actor.parameters()))
+        for name, p in named_parameters:
+            assert p.grad.norm() > 0.0, f"parameter {name} has a null gradient"
 
     @pytest.mark.skipif(not _has_functorch, reason="functorch not installed")
     @pytest.mark.parametrize("n", list(range(4)))
     @pytest.mark.parametrize("device", get_available_devices())
-    @pytest.mark.parametrize("delay_actor,delay_value", [(False, False), (True, True)])
+    @pytest.mark.parametrize("delay_actor,delay_qvalue", [(False, False), (True, True)])
     @pytest.mark.parametrize("policy_noise", [0.1, 1.0])
     @pytest.mark.parametrize("noise_clip", [0.1, 1.0])
     def test_td3_batcher(
-        self,
-        n,
-        delay_actor,
-        delay_value,
-        device,
-        policy_noise,
-        noise_clip,
+        self, n, delay_actor, delay_qvalue, device, policy_noise, noise_clip, gamma=0.9
     ):
         torch.manual_seed(self.seed)
         actor = self._create_mock_actor(device=device)
@@ -847,18 +829,27 @@ def test_td3_batcher(
             actor,
             value,
             gamma=0.9,
-            loss_function="l2",
             policy_noise=policy_noise,
             noise_clip=noise_clip,
+            delay_qvalue=delay_qvalue,
             delay_actor=delay_actor,
-            delay_value=delay_value,
         )
 
-        ms = MultiStep(gamma=0.9, n_steps_max=n).to(device)
-        ms_td = ms(td.clone())
+        ms = MultiStep(gamma=gamma, n_steps_max=n).to(device)
+
+        td_clone = td.clone()
+        ms_td = ms(td_clone)
+
+        torch.manual_seed(0)
+        np.random.seed(0)
+
         with _check_td_steady(ms_td):
             loss_ms = loss_fn(ms_td)
+        assert loss_fn.priority_key in ms_td.keys()
+
         with torch.no_grad():
+            torch.manual_seed(0)  # log-prob is computed with a random action
+            np.random.seed(0)
             loss = loss_fn(td)
         if n == 0:
             assert_allclose_td(td, ms_td.select(*list(td.keys())))
@@ -870,10 +861,50 @@ def test_td3_batcher(
         else:
             with pytest.raises(AssertionError):
                 assert_allclose_td(loss, loss_ms)
+
         sum([item for _, item in loss_ms.items()]).backward()
-        parameters = list(actor.parameters()) + list(value.parameters())
-        for p in parameters:
-            assert p.grad.norm() > 0.0
+        named_parameters = loss_fn.named_parameters()
+        for name, p in named_parameters:
+            assert p.grad.norm() > 0.0, f"parameter {name} has null gradient"
+
+        # Check param update effect on targets
+        target_actor = loss_fn.target_actor_network_params.clone().values(
+            include_nested=True, leaves_only=True
+        )
+        target_qvalue = loss_fn.target_qvalue_network_params.clone().values(
+            include_nested=True, leaves_only=True
+        )
+        for p in loss_fn.parameters():
+            p.data += torch.randn_like(p)
+        target_actor2 = loss_fn.target_actor_network_params.clone().values(
+            include_nested=True, leaves_only=True
+        )
+        target_qvalue2 = loss_fn.target_qvalue_network_params.clone().values(
+            include_nested=True, leaves_only=True
+        )
+        if loss_fn.delay_actor:
+            assert all((p1 == p2).all() for p1, p2 in zip(target_actor, target_actor2))
+        else:
+            assert not any(
+                (p1 == p2).any() for p1, p2 in zip(target_actor, target_actor2)
+            )
+        if loss_fn.delay_qvalue:
+            assert all(
+                (p1 == p2).all() for p1, p2 in zip(target_qvalue, target_qvalue2)
+            )
+        else:
+            assert not any(
+                (p1 == p2).any() for p1, p2 in zip(target_qvalue, target_qvalue2)
+            )
+
+        # check that policy is updated after parameter update
+        actorp_set = set(actor.parameters())
+        loss_fnp_set = set(loss_fn.parameters())
+        assert len(actorp_set.intersection(loss_fnp_set)) == len(actorp_set)
+        parameters = [p.clone() for p in actor.parameters()]
+        for p in loss_fn.parameters():
+            p.data += torch.randn_like(p)
+        assert all((p1 != p2).all() for p1, p2 in zip(parameters, actor.parameters()))
 
 
 class TestSAC:
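
Note on the parameter access used above: the rewritten assertions read leaf tensors out of the loss module's parameter containers with values(include_nested=True, leaves_only=True), spelled values(True, True) in the shorter checks, which suggests these containers are nested TensorDicts rather than flat lists of tensors. A minimal sketch of that access pattern, assuming the standalone tensordict package (the exact import path may differ in the TorchRL version this commit targets):

import torch
from tensordict import TensorDict

# A nested parameter container, similar in shape to loss_fn.qvalue_network_params.
params = TensorDict(
    {"module": {"weight": torch.randn(3, 3), "bias": torch.randn(3)}},
    batch_size=[],
)

# Iterate only the leaf tensors, descending into nested TensorDicts.
leaves = list(params.values(include_nested=True, leaves_only=True))
assert all(isinstance(p, torch.Tensor) for p in leaves)
assert len(leaves) == 2  # weight and bias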

torchrl/objectives/td3.py

Lines changed: 14 additions & 8 deletions
@@ -42,12 +42,12 @@ class TD3Loss(LossModule):
             `"td_error"`.
         loss_function (str, optional): loss function to be used for the Q-value. Can be one of `"smooth_l1"`, "l2",
             "l1", Default is "smooth_l1".
+        delay_actor (bool, optional): whether to separate the target actor networks from the actor networks used for
+            data collection. Default is :obj:`False`.
         delay_qvalue (bool, optional): Whether to separate the target Q value networks from the Q value networks used
             for data collection. Default is :obj:`False`.
     """
 
-    delay_actor: bool = False
-
     def __init__(
         self,
         actor_network: SafeModule,
@@ -58,28 +58,33 @@ def __init__(
         noise_clip: float = 0.5,
         priotity_key: str = "td_error",
         loss_function: str = "smooth_l1",
-        delay_qvalue: bool = True,
-    ):
+        delay_actor: bool = False,
+        delay_qvalue: bool = False,
+    ) -> None:
         if not _has_functorch:
             raise ImportError(
                 f"Failed to import functorch with error message:\n{FUNCTORCH_ERR}"
             )
 
         super().__init__()
+
+        self.delay_actor = delay_actor
+        self.delay_qvalue = delay_qvalue
+
         self.convert_to_functional(
             actor_network,
             "actor_network",
             create_target_params=self.delay_actor,
         )
 
-        self.delay_qvalue = delay_qvalue
         self.convert_to_functional(
             qvalue_network,
             "qvalue_network",
             num_qvalue_nets,
             create_target_params=self.delay_qvalue,
             compare_against=list(actor_network.parameters()),
         )
+
         self.num_qvalue_nets = num_qvalue_nets
         self.register_buffer("gamma", torch.tensor(gamma))
         self.priority_key = priotity_key
@@ -203,14 +208,15 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
                 f"QVal and actor loss have different shape: {loss_qval.shape} and {loss_actor.shape}"
             )
         td_out = TensorDict(
-            {
+            source={
                 "loss_actor": loss_actor.mean(),
                 "loss_qvalue": loss_qval.mean(),
+                "pred_value": pred_val.mean().detach(),
                 "state_action_value_actor": state_action_value_actor.mean().detach(),
-                "next.state_value": next_state_value.mean().detach(),
+                "next_state_value": next_state_value.mean().detach(),
                 "target_value": target_value.mean().detach(),
             },
-            [],
+            batch_size=[],
         )
 
         return td_out
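
For orientation, a minimal sketch of constructing the loss with the reworked signature. The TD3Loss keyword arguments (gamma, policy_noise, noise_clip, delay_actor, delay_qvalue) come from this diff; the Actor/ValueOperator wrappers, the import paths, and the network sizes are illustrative stand-ins for the test fixtures, not part of this commit:

from torch import nn
from torchrl.modules import Actor, ValueOperator
from torchrl.objectives.td3 import TD3Loss

obs_dim, action_dim = 4, 2  # arbitrary sizes, for illustration only

# Hypothetical actor: reads "observation", writes "action".
actor = Actor(nn.Linear(obs_dim, action_dim), in_keys=["observation"])

# Hypothetical Q-value network: reads "observation" and "action".
qvalue = ValueOperator(
    nn.Linear(obs_dim + action_dim, 1), in_keys=["observation", "action"]
)

loss_fn = TD3Loss(
    actor,
    qvalue,
    gamma=0.99,
    policy_noise=0.2,
    noise_clip=0.5,
    delay_actor=True,   # keep separate target actor parameters (default is now False)
    delay_qvalue=True,  # keep separate target Q-value parameters (default changed from True to False)
)

With both delay flags set, the target parameters are kept separate from the trained ones, which is what the delay_actor/delay_qvalue branches of the tests above check.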
