@@ -196,10 +196,10 @@ def __getitem__(
196196 index = [index ]
197197 return self .__class__ (matrix = self .get_matrix ()[index ])
198198
199- def compose (self , * others ) :
199+ def compose (self , * others : "Transform3d" ) -> "Transform3d" :
200200 """
201- Return a new Transform3d with the transforms to compose stored as
202- an internal list.
201+ Return a new Transform3d representing the composition of self with the
202+ given other transforms, which will be stored as an internal list.
203203
204204 Args:
205205 *others: Any number of Transform3d objects
@@ -216,7 +216,7 @@ def compose(self, *others):
216216 out ._transforms = self ._transforms + list (others )
217217 return out
218218
219- def get_matrix (self ):
219+ def get_matrix (self ) -> torch . Tensor :
220220 """
221221 Return a matrix which is the result of composing this transform
222222 with others stored in self.transforms. Where necessary transforms
@@ -240,13 +240,13 @@ def get_matrix(self):
240240 composed_matrix = _broadcast_bmm (composed_matrix , other_matrix )
241241 return composed_matrix
242242
243- def _get_matrix_inverse (self ):
243+ def _get_matrix_inverse (self ) -> torch . Tensor :
244244 """
245245 Return the inverse of self._matrix.
246246 """
247247 return torch .inverse (self ._matrix )
248248
249- def inverse (self , invert_composed : bool = False ):
249+ def inverse (self , invert_composed : bool = False ) -> "Transform3d" :
250250 """
251251 Returns a new Transform3d object that represents an inverse of the
252252 current transformation.
@@ -295,14 +295,24 @@ def inverse(self, invert_composed: bool = False):
295295
296296 return tinv
297297
298- def stack (self , * others ):
298+ def stack (self , * others : "Transform3d" ) -> "Transform3d" :
299+ """
300+ Return a new batched Transform3d representing the batch elements from
301+ self and all the given other transforms all batched together.
302+
303+ Args:
304+ *others: Any number of Transform3d objects
305+
306+ Returns:
307+ A new Transform3d.
308+ """
299309 transforms = [self ] + list (others )
300- matrix = torch .cat ([t ._matrix for t in transforms ], dim = 0 )
310+ matrix = torch .cat ([t .get_matrix () for t in transforms ], dim = 0 )
301311 out = Transform3d (dtype = self .dtype , device = self .device )
302312 out ._matrix = matrix
303313 return out
304314
305- def transform_points (self , points , eps : Optional [float ] = None ):
315+ def transform_points (self , points , eps : Optional [float ] = None ) -> torch . Tensor :
306316 """
307317 Use this transform to transform a set of 3D points. Assumes row major
308318 ordering of the input points.
@@ -347,7 +357,7 @@ def transform_points(self, points, eps: Optional[float] = None):
347357
348358 return points_out
349359
350- def transform_normals (self , normals ):
360+ def transform_normals (self , normals ) -> torch . Tensor :
351361 """
352362 Use this transform to transform a set of normal vectors.
353363
@@ -379,19 +389,19 @@ def transform_normals(self, normals):
379389
380390 return normals_out
381391
382- def translate (self , * args , ** kwargs ):
392+ def translate (self , * args , ** kwargs ) -> "Transform3d" :
383393 return self .compose (Translate (device = self .device , * args , ** kwargs ))
384394
385- def scale (self , * args , ** kwargs ):
395+ def scale (self , * args , ** kwargs ) -> "Transform3d" :
386396 return self .compose (Scale (device = self .device , * args , ** kwargs ))
387397
388- def rotate (self , * args , ** kwargs ):
398+ def rotate (self , * args , ** kwargs ) -> "Transform3d" :
389399 return self .compose (Rotate (device = self .device , * args , ** kwargs ))
390400
391- def rotate_axis_angle (self , * args , ** kwargs ):
401+ def rotate_axis_angle (self , * args , ** kwargs ) -> "Transform3d" :
392402 return self .compose (RotateAxisAngle (device = self .device , * args , ** kwargs ))
393403
394- def clone (self ):
404+ def clone (self ) -> "Transform3d" :
395405 """
396406 Deep copy of Transforms object. All internal tensors are cloned
397407 individually.
@@ -411,7 +421,7 @@ def to(
411421 device : Device ,
412422 copy : bool = False ,
413423 dtype : Optional [torch .dtype ] = None ,
414- ):
424+ ) -> "Transform3d" :
415425 """
416426 Match functionality of torch.Tensor.to()
417427 If copy = True or the self Tensor is on a different device, the
@@ -448,10 +458,10 @@ def to(
448458 ]
449459 return other
450460
451- def cpu (self ):
461+ def cpu (self ) -> "Transform3d" :
452462 return self .to ("cpu" )
453463
454- def cuda (self ):
464+ def cuda (self ) -> "Transform3d" :
455465 return self .to ("cuda" )
456466
457467
@@ -486,7 +496,7 @@ def __init__(
486496 mat [:, 3 , :3 ] = xyz
487497 self ._matrix = mat
488498
489- def _get_matrix_inverse (self ):
499+ def _get_matrix_inverse (self ) -> torch . Tensor :
490500 """
491501 Return the inverse of self._matrix.
492502 """
@@ -533,7 +543,7 @@ def __init__(
533543 mat [:, 2 , 2 ] = xyz [:, 2 ]
534544 self ._matrix = mat
535545
536- def _get_matrix_inverse (self ):
546+ def _get_matrix_inverse (self ) -> torch . Tensor :
537547 """
538548 Return the inverse of self._matrix.
539549 """
@@ -575,7 +585,7 @@ def __init__(
575585 mat [:, :3 , :3 ] = R
576586 self ._matrix = mat
577587
578- def _get_matrix_inverse (self ):
588+ def _get_matrix_inverse (self ) -> torch . Tensor :
579589 """
580590 Return the inverse of self._matrix.
581591 """
@@ -622,7 +632,7 @@ def __init__(
622632 super ().__init__ (device = angle .device , R = R )
623633
624634
625- def _handle_coord (c , dtype : torch .dtype , device : torch .device ):
635+ def _handle_coord (c , dtype : torch .dtype , device : torch .device ) -> torch . Tensor :
626636 """
627637 Helper function for _handle_input.
628638
@@ -649,7 +659,7 @@ def _handle_input(
649659 device : Optional [Device ],
650660 name : str ,
651661 allow_singleton : bool = False ,
652- ):
662+ ) -> torch . Tensor :
653663 """
654664 Helper function to handle parsing logic for building transforms. The output
655665 is always a tensor of shape (N, 3), but there are several types of allowed
@@ -707,7 +717,9 @@ def _handle_input(
707717 return xyz
708718
709719
710- def _handle_angle_input (x , dtype : torch .dtype , device : Optional [Device ], name : str ):
720+ def _handle_angle_input (
721+ x , dtype : torch .dtype , device : Optional [Device ], name : str
722+ ) -> torch .Tensor :
711723 """
712724 Helper function for building a rotation function using angles.
713725 The output is always of shape (N,).
@@ -725,7 +737,7 @@ def _handle_angle_input(x, dtype: torch.dtype, device: Optional[Device], name: s
725737 return _handle_coord (x , dtype , device_ )
726738
727739
728- def _broadcast_bmm (a , b ):
740+ def _broadcast_bmm (a , b ) -> torch . Tensor :
729741 """
730742 Batch multiply two matrices and broadcast if necessary.
731743
0 commit comments