Skip to content

Commit 55127fe

Browse files
committed
amend
1 parent 92f7637 commit 55127fe

File tree

4 files changed

+100
-1
lines changed

4 files changed

+100
-1
lines changed

examples/multiagent/mappo_ippo.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -174,7 +174,7 @@ def train(cfg: "DictConfig"): # noqa: F821
174174
with torch.no_grad():
175175
loss_module.value_estimator(
176176
tensordict_data,
177-
params=loss_module.critic_params,
177+
params=loss_module.critic_network_params,
178178
target_params=loss_module.target_critic_params,
179179
)
180180
current_frames = tensordict_data.numel()

torchrl/objectives/a2c.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
# This source code is licensed under the MIT license found in the
44
# LICENSE file in the root directory of this source tree.
55
import contextlib
6+
import logging
67
import warnings
78
from copy import deepcopy
89
from dataclasses import dataclass
@@ -297,12 +298,44 @@ def functional(self):
297298

298299
@property
299300
def actor(self):
301+
logging.warning(
302+
f"{self.__class__.__name__}.actor is deprecated, use {self.__class__.__name__}.actor_network instead. This "
303+
"link will be removed in v0.4."
304+
)
300305
return self.actor_network
301306

302307
@property
303308
def critic(self):
309+
logging.warning(
310+
f"{self.__class__.__name__}.critic is deprecated, use {self.__class__.__name__}.critic_network instead. This "
311+
"link will be removed in v0.4."
312+
)
304313
return self.critic_network
305314

315+
@property
316+
def actor_params(self):
317+
logging.warning(
318+
f"{self.__class__.__name__}.actor_params is deprecated, use {self.__class__.__name__}.actor_network_params instead. This "
319+
"link will be removed in v0.4."
320+
)
321+
return self.actor_network_params
322+
323+
@property
324+
def critic_params(self):
325+
logging.warning(
326+
f"{self.__class__.__name__}.critic_params is deprecated, use {self.__class__.__name__}.critic_network_params instead. This "
327+
"link will be removed in v0.4."
328+
)
329+
return self.critic_network_params
330+
331+
@property
332+
def target_critic_params(self):
333+
logging.warning(
334+
f"{self.__class__.__name__}.target_critic_params is deprecated, use {self.__class__.__name__}.target_critic_network_params instead. This "
335+
"link will be removed in v0.4."
336+
)
337+
return self.target_critic_network_params
338+
306339
@property
307340
def in_keys(self):
308341
keys = [

torchrl/objectives/ppo.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from __future__ import annotations
66

77
import contextlib
8+
import logging
89

910
import math
1011
import warnings
@@ -345,12 +346,44 @@ def functional(self):
345346

346347
@property
347348
def actor(self):
349+
logging.warning(
350+
f"{self.__class__.__name__}.actor is deprecated, use {self.__class__.__name__}.actor_network instead. This "
351+
"link will be removed in v0.4."
352+
)
348353
return self.actor_network
349354

350355
@property
351356
def critic(self):
357+
logging.warning(
358+
f"{self.__class__.__name__}.critic is deprecated, use {self.__class__.__name__}.critic_network instead. This "
359+
"link will be removed in v0.4."
360+
)
352361
return self.critic_network
353362

363+
@property
364+
def actor_params(self):
365+
logging.warning(
366+
f"{self.__class__.__name__}.actor_params is deprecated, use {self.__class__.__name__}.actor_network_params instead. This "
367+
"link will be removed in v0.4."
368+
)
369+
return self.actor_network_params
370+
371+
@property
372+
def critic_params(self):
373+
logging.warning(
374+
f"{self.__class__.__name__}.critic_params is deprecated, use {self.__class__.__name__}.critic_network_params instead. This "
375+
"link will be removed in v0.4."
376+
)
377+
return self.critic_network_params
378+
379+
@property
380+
def target_critic_params(self):
381+
logging.warning(
382+
f"{self.__class__.__name__}.target_critic_params is deprecated, use {self.__class__.__name__}.target_critic_network_params instead. This "
383+
"link will be removed in v0.4."
384+
)
385+
return self.target_critic_network_params
386+
354387
def _set_in_keys(self):
355388
keys = [
356389
self.tensor_keys.action,

torchrl/objectives/reinforce.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from __future__ import annotations
66

77
import contextlib
8+
import logging
89
import warnings
910
from copy import deepcopy
1011
from dataclasses import dataclass
@@ -289,12 +290,44 @@ def functional(self):
289290

290291
@property
291292
def actor(self):
293+
logging.warning(
294+
f"{self.__class__.__name__}.actor is deprecated, use {self.__class__.__name__}.actor_network instead. This "
295+
"link will be removed in v0.4."
296+
)
292297
return self.actor_network
293298

294299
@property
295300
def critic(self):
301+
logging.warning(
302+
f"{self.__class__.__name__}.critic is deprecated, use {self.__class__.__name__}.critic_network instead. This "
303+
"link will be removed in v0.4."
304+
)
296305
return self.critic_network
297306

307+
@property
308+
def actor_params(self):
309+
logging.warning(
310+
f"{self.__class__.__name__}.actor_params is deprecated, use {self.__class__.__name__}.actor_network_params instead. This "
311+
"link will be removed in v0.4."
312+
)
313+
return self.actor_network_params
314+
315+
@property
316+
def critic_params(self):
317+
logging.warning(
318+
f"{self.__class__.__name__}.critic_params is deprecated, use {self.__class__.__name__}.critic_network_params instead. This "
319+
"link will be removed in v0.4."
320+
)
321+
return self.critic_network_params
322+
323+
@property
324+
def target_critic_params(self):
325+
logging.warning(
326+
f"{self.__class__.__name__}.target_critic_params is deprecated, use {self.__class__.__name__}.target_critic_network_params instead. This "
327+
"link will be removed in v0.4."
328+
)
329+
return self.target_critic_network_params
330+
298331
def _forward_value_estimator_keys(self, **kwargs) -> None:
299332
if self._value_estimator is not None:
300333
self._value_estimator.set_keys(

0 commit comments

Comments
 (0)