@@ -43,27 +43,8 @@ class DecoderFunctionBase(ReplaceableBase, torch.nn.Module):
4343 """
4444 Decoding function is a torch.nn.Module which takes the embedding of a location in
4545 space and transforms it into the required quantity (for example density and color).
46-
47- Members:
48- param_groups: dictionary where keys are names of individual parameters
49- or module members and values are the parameter group where the
50- parameter/member will be sorted to. "self" key is used to denote the
51- parameter group at the module level. Possible keys, including the "self" key
52- do not have to be defined. By default all parameters are put into "default"
53- parameter group and have the learning rate defined in the optimizer,
54- it can be overridden at the:
55- - module level with “self” key, all the parameters and child
56- module's parameters will be put to that parameter group
57- - member level, which is the same as if the `param_groups` in that
58- member has key=“self” and value equal to that parameter group.
59- This is useful if members do not have `param_groups`, for
60- example torch.nn.Linear.
61- - parameter level, parameter with the same name as the key
62- will be put to that parameter group.
6346 """
6447
65- param_groups : Dict [str , str ] = field (default_factory = lambda : {})
66-
6748 def __post_init__ (self ):
6849 super ().__init__ ()
6950
@@ -280,11 +261,30 @@ def forward(self, x: torch.Tensor, z: Optional[torch.Tensor] = None):
280261class MLPDecoder (DecoderFunctionBase ):
281262 """
282263 Decoding function which uses `MLPWithIputSkips` to convert the embedding to output.
283- If using Implicitron config system `input_dim` of the `network` is changed to the
284- value of `input_dim` member and `input_skips` is removed.
264+ The `input_dim` of the `network` is set from the value of `input_dim` member.
265+
266+ Members:
267+ input_dim: dimension of input.
268+ param_groups: dictionary where keys are names of individual parameters
269+ or module members and values are the parameter group where the
270+ parameter/member will be sorted to. "self" key is used to denote the
271+ parameter group at the module level. Possible keys, including the "self" key
272+ do not have to be defined. By default all parameters are put into "default"
273+ parameter group and have the learning rate defined in the optimizer,
274+ it can be overridden at the:
275+ - module level with “self” key, all the parameters and child
276+ module's parameters will be put to that parameter group
277+ - member level, which is the same as if the `param_groups` in that
278+ member has key=“self” and value equal to that parameter group.
279+ This is useful if members do not have `param_groups`, for
280+ example torch.nn.Linear.
281+ - parameter level, parameter with the same name as the key
282+ will be put to that parameter group.
283+ network_args: configuration for MLPWithInputSkips
285284 """
286285
287286 input_dim : int = 3
287+ param_groups : Dict [str , str ] = field (default_factory = lambda : {})
288288 network : MLPWithInputSkips
289289
290290 def __post_init__ (self ):
0 commit comments