From 1ad20f803ccf55870797e439e2a8198ed1ea7fb2 Mon Sep 17 00:00:00 2001 From: albert bou Date: Mon, 2 Oct 2023 18:34:06 +0200 Subject: [PATCH 1/2] pickable and tests --- test/test_transforms.py | 10 ++++++++++ torchrl/envs/transforms/transforms.py | 16 +++++++++++++++- 2 files changed, 25 insertions(+), 1 deletion(-) diff --git a/test/test_transforms.py b/test/test_transforms.py index 5299a72d854..31d5b99837f 100644 --- a/test/test_transforms.py +++ b/test/test_transforms.py @@ -6,6 +6,7 @@ import argparse import itertools +import pickle import sys from copy import copy from functools import partial @@ -7290,6 +7291,15 @@ def test_vecnorm_rollout(self, parallel, thr=0.2, N=200): env_t.close() self.SEED = 0 + def test_pickable(self, rb_type, sampler, writer, storage, size): + + transform = VecNorm() + serialized = pickle.dumps(transform) + transform2 = pickle.loads(serialized) + assert transform.__dict__.keys() == transform2.__dict__.keys() + for key in sorted(transform.__dict__.keys()): + assert isinstance(transform.__dict__[key], type(transform2.__dict__[key])) + def test_added_transforms_are_in_eval_mode_trivial(): base_env = ContinuousActionVecMockEnv() diff --git a/torchrl/envs/transforms/transforms.py b/torchrl/envs/transforms/transforms.py index 879432569f7..75b244f88d4 100644 --- a/torchrl/envs/transforms/transforms.py +++ b/torchrl/envs/transforms/transforms.py @@ -11,7 +11,7 @@ from copy import copy from functools import wraps from textwrap import indent -from typing import Any, List, Optional, OrderedDict, Sequence, Tuple, Union +from typing import Any, Dict, List, Optional, OrderedDict, Sequence, Tuple, Union import numpy as np @@ -4337,6 +4337,20 @@ def __repr__(self) -> str: f"eps={self.eps:4.4f}, keys={self.in_keys})" ) + def __getstate__(self) -> Dict[str, Any]: + state = self.__dict__.copy() + _lock = state.pop("lock", None) + if _lock is not None: + state["lock_placeholder"] = None + return state + + def __setstate__(self, state: Dict[str, Any]): + if "lock_placeholder" in state: + state.pop("lock_placeholder") + _lock = mp.Lock() + state["lock"] = _lock + self.__dict__.update(state) + class RewardSum(Transform): """Tracks episode cumulative rewards. From d9a85e892d19fb5eb8a38ae75f5d97f23f523844 Mon Sep 17 00:00:00 2001 From: albert bou Date: Mon, 2 Oct 2023 18:35:00 +0200 Subject: [PATCH 2/2] fix --- test/test_transforms.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/test_transforms.py b/test/test_transforms.py index 31d5b99837f..f82853f6f8c 100644 --- a/test/test_transforms.py +++ b/test/test_transforms.py @@ -7291,7 +7291,7 @@ def test_vecnorm_rollout(self, parallel, thr=0.2, N=200): env_t.close() self.SEED = 0 - def test_pickable(self, rb_type, sampler, writer, storage, size): + def test_pickable(self): transform = VecNorm() serialized = pickle.dumps(transform)