@@ -65,7 +65,7 @@ class VoxelGridBase(ReplaceableBase, torch.nn.Module):
6565 padding : str = "zeros"
6666 mode : str = "bilinear"
6767 n_features : int = 1
68- resolution : Tuple [int , int , int ] = (64 , 64 , 64 )
68+ resolution : Tuple [int , int , int ] = (128 , 128 , 128 )
6969
7070 def __post_init__ (self ):
7171 super ().__init__ ()
@@ -507,8 +507,7 @@ class VoxelGridModule(Configurable, torch.nn.Module):
507507 voxel_grid_class_type : str = "FullResolutionVoxelGrid"
508508 voxel_grid : VoxelGridBase
509509
510- # pyre-fixme[8]: Attribute has type `Tuple[float, float, float]`; used as `float`.
511- extents : Tuple [float , float , float ] = 1.0
510+ extents : Tuple [float , float , float ] = (1.0 , 1.0 , 1.0 )
512511 translation : Tuple [float , float , float ] = (0.0 , 0.0 , 0.0 )
513512
514513 init_std : float = 0.1
@@ -552,13 +551,28 @@ def forward(self, points: torch.Tensor) -> torch.Tensor:
552551 grid_sizes = (2 , 2 , 2 ),
553552 # The locator object uses (x, y, z) convention for the
554553 # voxel size and translation.
555- voxel_size = self .extents ,
556- volume_translation = self .translation ,
557- # pyre-fixme [29]: `Union[BoundMethod[typing.Callable(torch._C._TensorBase...
554+ voxel_size = tuple ( self .extents ) ,
555+ volume_translation = tuple ( self .translation ) ,
556+ # pyre-ignore [29]
558557 device = next (self .params .values ()).device ,
559558 )
560559 # pyre-fixme[29]: `Union[torch._tensor.Tensor,
561560 # torch.nn.modules.module.Module]` is not a function.
562561 grid_values = self .voxel_grid .values_type (** self .params )
563562 # voxel grids operate with extra n_grids dimension, which we fix to one
564563 return self .voxel_grid .evaluate_world (points [None ], grid_values , locator )[0 ]
564+
565+ @staticmethod
566+ def get_output_dim (args : DictConfig ) -> int :
567+ """
568+ Utility to help predict the shape of the output of `forward`.
569+
570+ Args:
571+ args: DictConfig which would be used to initialize the object
572+ Returns:
573+ int: the length of the last dimension of the output tensor
574+ """
575+ grid = registry .get (VoxelGridBase , args ["voxel_grid_class_type" ])
576+ return grid .get_output_dim (
577+ args ["voxel_grid_" + args ["voxel_grid_class_type" ] + "_args" ]
578+ )
0 commit comments