@@ -514,32 +514,21 @@ def out_keys(self, values):
514
514
515
515
@dispatch
516
516
def forward (self , tensordict : TensorDictBase ) -> TensorDictBase :
517
- shape = None
518
- if tensordict .ndimension () > 1 :
519
- shape = tensordict .shape
520
- tensordict_reshape = tensordict .reshape (- 1 )
521
- else :
522
- tensordict_reshape = tensordict
523
-
524
- q_loss , metadata = self .q_loss (tensordict_reshape )
525
- cql_loss , cql_metadata = self .cql_loss (tensordict_reshape )
517
+ q_loss , metadata = self .q_loss (tensordict )
518
+ cql_loss , cql_metadata = self .cql_loss (tensordict )
526
519
if self .with_lagrange :
527
- alpha_prime_loss , alpha_prime_metadata = self .alpha_prime_loss (
528
- tensordict_reshape
529
- )
520
+ alpha_prime_loss , alpha_prime_metadata = self .alpha_prime_loss (tensordict )
530
521
metadata .update (alpha_prime_metadata )
531
- loss_actor_bc , bc_metadata = self .actor_bc_loss (tensordict_reshape )
532
- loss_actor , actor_metadata = self .actor_loss (tensordict_reshape )
522
+ loss_actor_bc , bc_metadata = self .actor_bc_loss (tensordict )
523
+ loss_actor , actor_metadata = self .actor_loss (tensordict )
533
524
loss_alpha , alpha_metadata = self .alpha_loss (actor_metadata )
534
525
metadata .update (bc_metadata )
535
526
metadata .update (cql_metadata )
536
527
metadata .update (actor_metadata )
537
528
metadata .update (alpha_metadata )
538
- tensordict_reshape .set (
529
+ tensordict .set (
539
530
self .tensor_keys .priority , metadata .pop ("td_error" ).detach ().max (0 ).values
540
531
)
541
- if shape :
542
- tensordict .update (tensordict_reshape .view (shape ))
543
532
out = {
544
533
"loss_actor" : loss_actor ,
545
534
"loss_actor_bc" : loss_actor_bc ,
@@ -682,7 +671,9 @@ def _get_value_v(self, tensordict, _alpha, actor_params, qval_params):
682
671
)
683
672
# take max over actions
684
673
state_action_value = state_action_value .reshape (
685
- self .num_qvalue_nets , tensordict .shape [0 ], self .num_random , - 1
674
+ torch .Size (
675
+ [self .num_qvalue_nets , * tensordict .shape , self .num_random , - 1 ]
676
+ )
686
677
).max (- 2 )[0 ]
687
678
# take min over qvalue nets
688
679
next_state_value = state_action_value .min (0 )[0 ]
@@ -739,14 +730,13 @@ def cql_loss(self, tensordict: TensorDictBase) -> Tuple[Tensor, dict]:
739
730
"This could be caused by calling cql_loss method before q_loss method."
740
731
)
741
732
742
- random_actions_tensor = (
743
- torch .FloatTensor (
744
- tensordict .shape [0 ] * self .num_random ,
733
+ random_actions_tensor = pred_q1 .new_empty (
734
+ (
735
+ * tensordict .shape [:- 1 ],
736
+ tensordict .shape [- 1 ] * self .num_random ,
745
737
tensordict [self .tensor_keys .action ].shape [- 1 ],
746
738
)
747
- .uniform_ (- 1 , 1 )
748
- .to (tensordict .device )
749
- )
739
+ ).uniform_ (- 1 , 1 )
750
740
curr_actions_td , curr_log_pis = self ._get_policy_actions (
751
741
tensordict .copy (),
752
742
self .actor_network_params ,
@@ -833,31 +823,31 @@ def filter_and_repeat(name, x):
833
823
q_new [0 ] - new_log_pis .detach ().unsqueeze (- 1 ),
834
824
q_curr [0 ] - curr_log_pis .detach ().unsqueeze (- 1 ),
835
825
],
836
- 1 ,
826
+ - 1 ,
837
827
)
838
828
cat_q2 = torch .cat (
839
829
[
840
830
q_random [1 ] - random_density ,
841
831
q_new [1 ] - new_log_pis .detach ().unsqueeze (- 1 ),
842
832
q_curr [1 ] - curr_log_pis .detach ().unsqueeze (- 1 ),
843
833
],
844
- 1 ,
834
+ - 1 ,
845
835
)
846
836
847
837
min_qf1_loss = (
848
- torch .logsumexp (cat_q1 / self .temperature , dim = 1 )
838
+ torch .logsumexp (cat_q1 / self .temperature , dim = - 1 )
849
839
* self .min_q_weight
850
840
* self .temperature
851
841
)
852
842
min_qf2_loss = (
853
- torch .logsumexp (cat_q2 / self .temperature , dim = 1 )
843
+ torch .logsumexp (cat_q2 / self .temperature , dim = - 1 )
854
844
* self .min_q_weight
855
845
* self .temperature
856
846
)
857
847
858
848
# Subtract the log likelihood of data
859
- cql_q1_loss = min_qf1_loss - pred_q1 * self .min_q_weight
860
- cql_q2_loss = min_qf2_loss - pred_q2 * self .min_q_weight
849
+ cql_q1_loss = min_qf1_loss . flatten () - pred_q1 * self .min_q_weight
850
+ cql_q2_loss = min_qf2_loss . flatten () - pred_q2 * self .min_q_weight
861
851
862
852
# write cql losses in tensordict for alpha prime loss
863
853
tensordict .set (self .tensor_keys .cql_q1_loss , cql_q1_loss )
@@ -1080,9 +1070,9 @@ def __init__(
1080
1070
self .loss_function = loss_function
1081
1071
if action_space is None :
1082
1072
# infer from value net
1083
- try :
1073
+ if hasattr ( value_network , "action_space" ) :
1084
1074
action_space = value_network .spec
1085
- except AttributeError :
1075
+ else :
1086
1076
# let's try with action_space then
1087
1077
try :
1088
1078
action_space = value_network .action_space
@@ -1205,8 +1195,6 @@ def value_loss(
1205
1195
with torch .no_grad ():
1206
1196
td_error = (pred_val_index - target_value ).pow (2 )
1207
1197
td_error = td_error .unsqueeze (- 1 )
1208
- if tensordict .device is not None :
1209
- td_error = td_error .to (tensordict .device )
1210
1198
1211
1199
tensordict .set (
1212
1200
self .tensor_keys .priority ,
0 commit comments