Skip to content

Commit 98c57ee

Browse files
author
Vincent Moens
committed
[Feature] Force log_prob to return a tensordict when kwargs are passed to ProbabilisticTensorDictSequential.log_prob
ghstack-source-id: 326d076 Pull Request resolved: #1146
1 parent 2d37d92 commit 98c57ee

File tree

2 files changed

+64
-8
lines changed

2 files changed

+64
-8
lines changed

tensordict/nn/probabilistic.py

Lines changed: 42 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -422,21 +422,49 @@ def log_prob(
422422
if dist.aggregate_probabilities is not None:
423423
aggregate_probabilities_inp = dist.aggregate_probabilities
424424
else:
425-
# TODO: warning
425+
warnings.warn(
426+
f"aggregate_probabilities wasn't defined in the {type(self).__name__} instance. "
427+
f"It couldn't be retrieved from the CompositeDistribution object either. "
428+
f"Currently, the aggregate_probability will be `True` in this case but in a future release "
429+
f"(v0.9) this will change and `aggregate_probabilities` will default to ``False`` such "
430+
f"that log_prob will return a tensordict with the log-prob values. To silence this warning, "
431+
f"pass `aggregate_probabilities` to the {type(self).__name__} constructor, to the distribution kwargs "
432+
f"or to the log-prob method.",
433+
category=DeprecationWarning,
434+
)
426435
aggregate_probabilities_inp = False
427436
else:
428437
aggregate_probabilities_inp = aggregate_probabilities
429438
if inplace is None:
430439
if dist.inplace is not None:
431440
inplace = dist.inplace
432441
else:
433-
# TODO: warning
442+
warnings.warn(
443+
f"inplace wasn't defined in the {type(self).__name__} instance. "
444+
f"It couldn't be retrieved from the CompositeDistribution object either. "
445+
f"Currently, the `inplace` will be `True` in this case but in a future release "
446+
f"(v0.9) this will change and `inplace` will default to ``False`` such "
447+
f"that log_prob will return a new tensordict containing only the log-prob values. To silence this warning, "
448+
f"pass `inplace` to the {type(self).__name__} constructor, to the distribution kwargs "
449+
f"or to the log-prob method.",
450+
category=DeprecationWarning,
451+
)
434452
inplace = True
435453
if include_sum is None:
436454
if dist.include_sum is not None:
437455
include_sum = dist.include_sum
438456
else:
439-
# TODO: warning
457+
warnings.warn(
458+
f"include_sum wasn't defined in the {type(self).__name__} instance. "
459+
f"It couldn't be retrieved from the CompositeDistribution object either. "
460+
f"Currently, the `include_sum` will be `True` in this case but in a future release "
461+
f"(v0.9) this will change and `include_sum` will default to ``False`` such "
462+
f"that log_prob will return a new tensordict containing only the leaf log-prob values. "
463+
f"To silence this warning, "
464+
f"pass `include_sum` to the {type(self).__name__} constructor, to the distribution kwargs "
465+
f"or to the log-prob method.",
466+
category=DeprecationWarning,
467+
)
440468
include_sum = True
441469
lp = dist.log_prob(
442470
tensordict,
@@ -446,6 +474,7 @@ def log_prob(
446474
)
447475
if is_tensor_collection(lp) and aggregate_probabilities is None:
448476
return lp.get(dist.log_prob_key)
477+
return lp
449478
else:
450479
return dist.log_prob(tensordict.get(self.out_keys[0]))
451480

@@ -1027,8 +1056,9 @@ def log_prob(
10271056
):
10281057
"""Returns the log-probability of the input tensordict.
10291058
1030-
If `return_composite` is ``True`` and the distribution is a :class:`~tensordict.nn.CompositeDistribution`,
1031-
this method will return the log-probability of the entire composite distribution.
1059+
If `self.return_composite` is ``True`` and the distribution is a :class:`~tensordict.nn.CompositeDistribution`,
1060+
or if any of :attr:`aggregate_probabilities`, :attr:`inplace` or :attr:`include_sum` is not ``None``, this method will return
1061+
the log-probability of the entire composite distribution.
10321062
10331063
Otherwise, it will only consider the last probabilistic module in the sequence.
10341064
@@ -1069,7 +1099,13 @@ def log_prob(
10691099
tensordict_inp = tensordict
10701100
if dist is None:
10711101
dist = self.get_dist(tensordict_inp)
1072-
if self.return_composite and isinstance(dist, CompositeDistribution):
1102+
return_composite = (
1103+
self.return_composite
1104+
or (aggregate_probabilities is not None)
1105+
or (inplace is not None)
1106+
or (include_sum is not None)
1107+
)
1108+
if return_composite and isinstance(dist, CompositeDistribution):
10731109
# Check the values within the dist - if not set, choose defaults
10741110
if aggregate_probabilities is None:
10751111
if self.aggregate_probabilities is not None:

test/test_nn.py

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2930,7 +2930,17 @@ def test_prob_module(self, interaction, return_log_prob, map_names):
29302930
assert key_logprob1 in sample
29312931
assert all(key in sample for key in module.out_keys)
29322932
sample_clone = sample.clone()
2933-
lp = module.log_prob(sample_clone)
2933+
with pytest.warns(
2934+
DeprecationWarning,
2935+
match="aggregate_probabilities wasn't defined in the ProbabilisticTensorDictModule",
2936+
), pytest.warns(
2937+
DeprecationWarning,
2938+
match="inplace wasn't defined in the ProbabilisticTensorDictModule",
2939+
), pytest.warns(
2940+
DeprecationWarning,
2941+
match="include_sum wasn't defined in the ProbabilisticTensorDictModule",
2942+
):
2943+
lp = module.log_prob(sample_clone)
29342944
assert isinstance(lp, torch.Tensor)
29352945
if return_log_prob:
29362946
torch.testing.assert_close(
@@ -3077,7 +3087,17 @@ def test_prob_module_seq(self, interaction, return_log_prob, ordereddict):
30773087
assert isinstance(dist, CompositeDistribution)
30783088

30793089
sample_clone = sample.clone()
3080-
lp = module.log_prob(sample_clone)
3090+
with pytest.warns(
3091+
DeprecationWarning,
3092+
match="aggregate_probabilities wasn't defined in the ProbabilisticTensorDictModule",
3093+
), pytest.warns(
3094+
DeprecationWarning,
3095+
match="inplace wasn't defined in the ProbabilisticTensorDictModule",
3096+
), pytest.warns(
3097+
DeprecationWarning,
3098+
match="include_sum wasn't defined in the ProbabilisticTensorDictModule",
3099+
):
3100+
lp = module.log_prob(sample_clone)
30813101

30823102
if return_log_prob:
30833103
torch.testing.assert_close(

0 commit comments

Comments
 (0)