From 558c597a04c8db293ca786338b653338ebb8eae6 Mon Sep 17 00:00:00 2001 From: Louis Faury Date: Fri, 30 May 2025 15:24:04 +0200 Subject: [PATCH 1/3] [BugFix] Categorical spec samples the right dtype when masked --- torchrl/data/tensor_specs.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/torchrl/data/tensor_specs.py b/torchrl/data/tensor_specs.py index a7dd608bf8c..e0abc446b1b 100644 --- a/torchrl/data/tensor_specs.py +++ b/torchrl/data/tensor_specs.py @@ -1964,7 +1964,11 @@ def rand(self, shape: torch.Size = None) -> torch.Tensor: else: mask_flat = mask shape_out = mask.shape[:-1] - m = torch.multinomial(mask_flat.float(), 1).reshape(shape_out) + m = ( + torch.multinomial(mask_flat.float(), 1) + .reshape(shape_out) + .to(self.dtype) + ) out = torch.nn.functional.one_hot(m, n).to(self.dtype) # torch.zeros((*shape, self.space.n), device=self.device, dtype=self.dtype) # out.scatter_(-1, m, 1) @@ -3926,7 +3930,7 @@ def rand(self, shape: torch.Size = None) -> torch.Tensor: "The last dimension of the mask must match the number of action allowed by the " f"Categorical spec. Got mask.shape={self.mask.shape} and n={n}." ) - out = torch.multinomial(mask_flat.float(), 1).reshape(shape_out) + out = torch.multinomial(mask_flat.float(), 1).reshape(shape_out).to(self.dtype) return out def index( From c48a106cac055dcfdaaeefe7d9b30fb44a8d884e Mon Sep 17 00:00:00 2001 From: Louis Faury Date: Fri, 30 May 2025 15:26:22 +0200 Subject: [PATCH 2/3] Revert --- torchrl/data/tensor_specs.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/torchrl/data/tensor_specs.py b/torchrl/data/tensor_specs.py index e0abc446b1b..a1f5ce53765 100644 --- a/torchrl/data/tensor_specs.py +++ b/torchrl/data/tensor_specs.py @@ -1964,11 +1964,7 @@ def rand(self, shape: torch.Size = None) -> torch.Tensor: else: mask_flat = mask shape_out = mask.shape[:-1] - m = ( - torch.multinomial(mask_flat.float(), 1) - .reshape(shape_out) - .to(self.dtype) - ) + m = torch.multinomial(mask_flat.float(), 1).reshape(shape_out) out = torch.nn.functional.one_hot(m, n).to(self.dtype) # torch.zeros((*shape, self.space.n), device=self.device, dtype=self.dtype) # out.scatter_(-1, m, 1) From 6e87a6d2dcdd8e870203d6a5f16e88278c2f7714 Mon Sep 17 00:00:00 2001 From: Louis Faury Date: Fri, 30 May 2025 15:36:03 +0200 Subject: [PATCH 3/3] Dtype test --- test/test_specs.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/test/test_specs.py b/test/test_specs.py index 98213508160..d984db64c3a 100644 --- a/test/test_specs.py +++ b/test/test_specs.py @@ -1331,6 +1331,14 @@ def test_categorical_action_spec_rand(self): sample = [sum(sample == i) for i in range(10)] assert chisquare(sample).pvalue > 0.1 + @pytest.mark.parametrize("dtype", [torch.int, torch.int32, torch.int64]) + def test_categorical_action_spec_rand_masked_right_dtype(self, dtype: torch.dtype): + torch.manual_seed(1) + action_spec = Categorical(2, dtype=dtype) + action_spec.update_mask(torch.tensor([True, False])) + sample = action_spec.rand() + assert sample.dtype == dtype + def test_mult_discrete_action_spec_rand(self): torch.manual_seed(0) ns = (10, 5)