88
99import  torch 
1010
11+ from  ..common .workaround  import  symeig3x3 
1112from  .utils  import  convert_pointclouds_to_tensor , get_point_covariances 
1213
1314
@@ -19,6 +20,8 @@ def estimate_pointcloud_normals(
1920    pointclouds : Union [torch .Tensor , "Pointclouds" ],
2021    neighborhood_size : int  =  50 ,
2122    disambiguate_directions : bool  =  True ,
23+     * ,
24+     use_symeig_workaround : bool  =  True ,
2225) ->  torch .Tensor :
2326    """ 
2427    Estimates the normals of a batch of `pointclouds`. 
@@ -33,6 +36,8 @@ def estimate_pointcloud_normals(
3336        geometry around each point. 
3437      **disambiguate_directions**: If `True`, uses the algorithm from [1] to 
3538        ensure sign consistency of the normals of neighboring points. 
39+       **use_symeig_workaround**: If `True`, uses a custom eigenvalue 
40+         calculation. 
3641
3742    Returns: 
3843      **normals**: A tensor of normals for each input point 
@@ -48,6 +53,7 @@ def estimate_pointcloud_normals(
4853        pointclouds ,
4954        neighborhood_size = neighborhood_size ,
5055        disambiguate_directions = disambiguate_directions ,
56+         use_symeig_workaround = use_symeig_workaround ,
5157    )
5258
5359    # the normals correspond to the first vector of each local coord frame 
@@ -60,6 +66,8 @@ def estimate_pointcloud_local_coord_frames(
6066    pointclouds : Union [torch .Tensor , "Pointclouds" ],
6167    neighborhood_size : int  =  50 ,
6268    disambiguate_directions : bool  =  True ,
69+     * ,
70+     use_symeig_workaround : bool  =  True ,
6371) ->  Tuple [torch .Tensor , torch .Tensor ]:
6472    """ 
6573    Estimates the principal directions of curvature (which includes normals) 
@@ -88,6 +96,8 @@ def estimate_pointcloud_local_coord_frames(
8896        geometry around each point. 
8997      **disambiguate_directions**: If `True`, uses the algorithm from [1] to 
9098        ensure sign consistency of the normals of neighboring points. 
99+       **use_symeig_workaround**: If `True`, uses a custom eigenvalue 
100+         calculation. 
91101
92102    Returns: 
93103      **curvatures**: The three principal curvatures of each point 
@@ -133,7 +143,10 @@ def estimate_pointcloud_local_coord_frames(
133143    # eigenvectors (=principal directions) in an ascending order of their 
134144    # corresponding eigenvalues, while the smallest eigenvalue's eigenvector 
135145    # corresponds to the normal direction 
136-     curvatures , local_coord_frames  =  torch .symeig (cov , eigenvectors = True )
146+     if  use_symeig_workaround :
147+         curvatures , local_coord_frames  =  symeig3x3 (cov , eigenvectors = True )
148+     else :
149+         curvatures , local_coord_frames  =  torch .symeig (cov , eigenvectors = True )
137150
138151    # disambiguate the directions of individual principal vectors 
139152    if  disambiguate_directions :
0 commit comments