@@ -85,7 +85,7 @@ class Volumes:
8585 are linearly interpolated over the spatial dimensions of the volume.
8686 - Note that the convention is the same as for the 5D version of the
8787 `torch.nn.functional.grid_sample` function called with
88- `align_corners==True` .
88+ the same value of `align_corners` argument .
8989 - Note that the local coordinate convention of `Volumes`
9090 (+X = left to right, +Y = top to bottom, +Z = away from the user)
9191 is *different* from the world coordinate convention of the
@@ -143,7 +143,7 @@ class Volumes:
143143 torch.nn.functional.grid_sample(
144144 v.densities(),
145145 v.get_coord_grid(world_coordinates=False),
146- align_corners=True ,
146+ align_corners=align_corners ,
147147 ) == v.densities(),
148148
149149 i.e. sampling the volume at trivial local coordinates
@@ -157,6 +157,7 @@ def __init__(
157157 features : Optional [_TensorBatch ] = None ,
158158 voxel_size : _VoxelSize = 1.0 ,
159159 volume_translation : _Translation = (0.0 , 0.0 , 0.0 ),
160+ align_corners : bool = True ,
160161 ) -> None :
161162 """
162163 Args:
@@ -186,6 +187,10 @@ def __init__(
186187 b) a Tensor of shape (3,)
187188 c) a Tensor of shape (minibatch, 3)
188189 d) a Tensor of shape (1,) (square voxels)
190+ **align_corners**: If set (default), the coordinates of the corner voxels are
191+ exactly −1 or +1 in the local coordinate system. Otherwise, the coordinates
192+ correspond to the centers of the corner voxels. Cf. the namesake argument to
193+ `torch.nn.functional.grid_sample`.
189194 """
190195
191196 # handle densities
@@ -206,6 +211,7 @@ def __init__(
206211 voxel_size = voxel_size ,
207212 volume_translation = volume_translation ,
208213 device = self .device ,
214+ align_corners = align_corners ,
209215 )
210216
211217 # handle features
@@ -336,6 +342,13 @@ def features_list(self) -> List[torch.Tensor]:
336342 return None
337343 return self ._features_densities_list (features_ )
338344
345+ def get_align_corners (self ) -> bool :
346+ """
347+ Return whether the corners of the voxels should be aligned with the
348+ image pixels.
349+ """
350+ return self .locator ._align_corners
351+
339352 def _features_densities_list (self , x : torch .Tensor ) -> List [torch .Tensor ]:
340353 """
341354 Retrieve the list representation of features/densities.
@@ -576,7 +589,7 @@ class VolumeLocator:
576589 are linearly interpolated over the spatial dimensions of the volume.
577590 - Note that the convention is the same as for the 5D version of the
578591 `torch.nn.functional.grid_sample` function called with
579- `align_corners==True` .
592+ the same value of `align_corners` argument .
580593 - Note that the local coordinate convention of `VolumeLocator`
581594 (+X = left to right, +Y = top to bottom, +Z = away from the user)
582595 is *different* from the world coordinate convention of the
@@ -634,7 +647,7 @@ class VolumeLocator:
634647 torch.nn.functional.grid_sample(
635648 v.densities(),
636649 v.get_coord_grid(world_coordinates=False),
637- align_corners=True ,
650+ align_corners=align_corners ,
638651 ) == v.densities(),
639652
640653 i.e. sampling the volume at trivial local coordinates
@@ -651,6 +664,7 @@ def __init__(
651664 device : torch .device ,
652665 voxel_size : _VoxelSize = 1.0 ,
653666 volume_translation : _Translation = (0.0 , 0.0 , 0.0 ),
667+ align_corners : bool = True ,
654668 ):
655669 """
656670 **batch_size** : Batch size of the underlying grids
@@ -674,15 +688,21 @@ def __init__(
674688 b) a Tensor of shape (3,)
675689 c) a Tensor of shape (minibatch, 3)
676690 d) a Tensor of shape (1,) (square voxels)
691+ **align_corners**: If set (default), the coordinates of the corner voxels are
692+ exactly −1 or +1 in the local coordinate system. Otherwise, the coordinates
693+ correspond to the centers of the corner voxels. Cf. the namesake argument to
694+ `torch.nn.functional.grid_sample`.
677695 """
678696 self .device = device
679697 self ._batch_size = batch_size
680698 self ._grid_sizes = self ._convert_grid_sizes2tensor (grid_sizes )
681699 self ._resolution = tuple (torch .max (self ._grid_sizes .cpu (), dim = 0 ).values )
700+ self ._align_corners = align_corners
682701
683702 # set the local_to_world transform
684703 self ._set_local_to_world_transform (
685- voxel_size = voxel_size , volume_translation = volume_translation
704+ voxel_size = voxel_size ,
705+ volume_translation = volume_translation ,
686706 )
687707
688708 def _convert_grid_sizes2tensor (
@@ -806,8 +826,17 @@ def _calculate_coordinate_grid(
806826 grid_sizes = self .get_grid_sizes ()
807827
808828 # generate coordinate axes
829+ def corner_coord_adjustment (r ):
830+ return 0.0 if self ._align_corners else 1.0 / r
831+
809832 vol_axes = [
810- torch .linspace (- 1.0 , 1.0 , r , dtype = torch .float32 , device = self .device )
833+ torch .linspace (
834+ - 1.0 + corner_coord_adjustment (r ),
835+ 1.0 - corner_coord_adjustment (r ),
836+ r ,
837+ dtype = torch .float32 ,
838+ device = self .device ,
839+ )
811840 for r in (de , he , wi )
812841 ]
813842
0 commit comments