@@ -422,21 +422,49 @@ def log_prob(
422422                if  dist .aggregate_probabilities  is  not   None :
423423                    aggregate_probabilities_inp  =  dist .aggregate_probabilities 
424424                else :
425-                     # TODO: warning 
425+                     warnings .warn (
426+                         f"aggregate_probabilities wasn't defined in the { type (self ).__name__ }   instance. " 
427+                         f"It couldn't be retrieved from the CompositeDistribution object either. " 
428+                         f"Currently, the aggregate_probability will be `True` in this case but in a future release " 
429+                         f"(v0.9) this will change and `aggregate_probabilities` will default to ``False`` such " 
430+                         f"that log_prob will return a tensordict with the log-prob values. To silence this warning, " 
431+                         f"pass `aggregate_probabilities` to the { type (self ).__name__ }   constructor, to the distribution kwargs " 
432+                         f"or to the log-prob method." ,
433+                         category = DeprecationWarning ,
434+                     )
426435                    aggregate_probabilities_inp  =  False 
427436            else :
428437                aggregate_probabilities_inp  =  aggregate_probabilities 
429438            if  inplace  is  None :
430439                if  dist .inplace  is  not   None :
431440                    inplace  =  dist .inplace 
432441                else :
433-                     # TODO: warning 
442+                     warnings .warn (
443+                         f"inplace wasn't defined in the { type (self ).__name__ }   instance. " 
444+                         f"It couldn't be retrieved from the CompositeDistribution object either. " 
445+                         f"Currently, the `inplace` will be `True` in this case but in a future release " 
446+                         f"(v0.9) this will change and `inplace` will default to ``False`` such " 
447+                         f"that log_prob will return a new tensordict containing only the log-prob values. To silence this warning, " 
448+                         f"pass `inplace` to the { type (self ).__name__ }   constructor, to the distribution kwargs " 
449+                         f"or to the log-prob method." ,
450+                         category = DeprecationWarning ,
451+                     )
434452                    inplace  =  True 
435453            if  include_sum  is  None :
436454                if  dist .include_sum  is  not   None :
437455                    include_sum  =  dist .include_sum 
438456                else :
439-                     # TODO: warning 
457+                     warnings .warn (
458+                         f"include_sum wasn't defined in the { type (self ).__name__ }   instance. " 
459+                         f"It couldn't be retrieved from the CompositeDistribution object either. " 
460+                         f"Currently, the `include_sum` will be `True` in this case but in a future release " 
461+                         f"(v0.9) this will change and `include_sum` will default to ``False`` such " 
462+                         f"that log_prob will return a new tensordict containing only the leaf log-prob values. " 
463+                         f"To silence this warning, " 
464+                         f"pass `include_sum` to the { type (self ).__name__ }   constructor, to the distribution kwargs " 
465+                         f"or to the log-prob method." ,
466+                         category = DeprecationWarning ,
467+                     )
440468                    include_sum  =  True 
441469            lp  =  dist .log_prob (
442470                tensordict ,
@@ -446,6 +474,7 @@ def log_prob(
446474            )
447475            if  is_tensor_collection (lp ) and  aggregate_probabilities  is  None :
448476                return  lp .get (dist .log_prob_key )
477+             return  lp 
449478        else :
450479            return  dist .log_prob (tensordict .get (self .out_keys [0 ]))
451480
@@ -1027,8 +1056,9 @@ def log_prob(
10271056    ):
10281057        """Returns the log-probability of the input tensordict. 
10291058
1030-         If `return_composite` is ``True`` and the distribution is a :class:`~tensordict.nn.CompositeDistribution`, 
1031-         this method will return the log-probability of the entire composite distribution. 
1059+         If `self.return_composite` is ``True`` and the distribution is a :class:`~tensordict.nn.CompositeDistribution`, 
1060+         or if any of :attr:`aggregate_probabilities`, :attr:`inplace` or :attr:`include_sum` this method will return 
1061+         the log-probability of the entire composite distribution. 
10321062
10331063        Otherwise, it will only consider the last probabilistic module in the sequence. 
10341064
@@ -1069,7 +1099,13 @@ def log_prob(
10691099            tensordict_inp  =  tensordict 
10701100        if  dist  is  None :
10711101            dist  =  self .get_dist (tensordict_inp )
1072-         if  self .return_composite  and  isinstance (dist , CompositeDistribution ):
1102+         return_composite  =  (
1103+             self .return_composite 
1104+             or  (aggregate_probabilities  is  not   None )
1105+             or  (inplace  is  not   None )
1106+             or  (include_sum  is  not   None )
1107+         )
1108+         if  return_composite  and  isinstance (dist , CompositeDistribution ):
10731109            # Check the values within the dist - if not set, choose defaults 
10741110            if  aggregate_probabilities  is  None :
10751111                if  self .aggregate_probabilities  is  not   None :
0 commit comments