@@ -421,9 +421,9 @@ def test_chamfer_pointcloud_object_withnormals(self):
421421            ("mean" , "mean" ),
422422            ("sum" , None ),
423423            ("mean" , None ),
424+             (None , None ),
424425        ]
425-         for  (point_reduction , batch_reduction ) in  reductions :
426- 
426+         for  point_reduction , batch_reduction  in  reductions :
427427            # Reinitialize all the tensors so that the 
428428            # backward pass can be computed. 
429429            points_normals  =  TestChamfer .init_pointclouds (
@@ -450,24 +450,52 @@ def test_chamfer_pointcloud_object_withnormals(self):
450450                batch_reduction = batch_reduction ,
451451            )
452452
453-             self .assertClose (cham_cloud , cham_tensor )
454-             self .assertClose (norm_cloud , norm_tensor )
455-             self ._check_gradients (
456-                 cham_tensor ,
457-                 norm_tensor ,
458-                 cham_cloud ,
459-                 norm_cloud ,
460-                 points_normals .cloud1 .points_list (),
461-                 points_normals .p1 ,
462-                 points_normals .cloud2 .points_list (),
463-                 points_normals .p2 ,
464-                 points_normals .cloud1 .normals_list (),
465-                 points_normals .n1 ,
466-                 points_normals .cloud2 .normals_list (),
467-                 points_normals .n2 ,
468-                 points_normals .p1_lengths ,
469-                 points_normals .p2_lengths ,
470-             )
453+             if  point_reduction  is  None :
454+                 cham_tensor_bidirectional  =  torch .hstack (
455+                     [cham_tensor [0 ], cham_tensor [1 ]]
456+                 )
457+                 norm_tensor_bidirectional  =  torch .hstack (
458+                     [norm_tensor [0 ], norm_tensor [1 ]]
459+                 )
460+                 cham_cloud_bidirectional  =  torch .hstack ([cham_cloud [0 ], cham_cloud [1 ]])
461+                 norm_cloud_bidirectional  =  torch .hstack ([norm_cloud [0 ], norm_cloud [1 ]])
462+                 self .assertClose (cham_cloud_bidirectional , cham_tensor_bidirectional )
463+                 self .assertClose (norm_cloud_bidirectional , norm_tensor_bidirectional )
464+                 self ._check_gradients (
465+                     cham_tensor_bidirectional ,
466+                     norm_tensor_bidirectional ,
467+                     cham_cloud_bidirectional ,
468+                     norm_cloud_bidirectional ,
469+                     points_normals .cloud1 .points_list (),
470+                     points_normals .p1 ,
471+                     points_normals .cloud2 .points_list (),
472+                     points_normals .p2 ,
473+                     points_normals .cloud1 .normals_list (),
474+                     points_normals .n1 ,
475+                     points_normals .cloud2 .normals_list (),
476+                     points_normals .n2 ,
477+                     points_normals .p1_lengths ,
478+                     points_normals .p2_lengths ,
479+                 )
480+             else :
481+                 self .assertClose (cham_cloud , cham_tensor )
482+                 self .assertClose (norm_cloud , norm_tensor )
483+                 self ._check_gradients (
484+                     cham_tensor ,
485+                     norm_tensor ,
486+                     cham_cloud ,
487+                     norm_cloud ,
488+                     points_normals .cloud1 .points_list (),
489+                     points_normals .p1 ,
490+                     points_normals .cloud2 .points_list (),
491+                     points_normals .p2 ,
492+                     points_normals .cloud1 .normals_list (),
493+                     points_normals .n1 ,
494+                     points_normals .cloud2 .normals_list (),
495+                     points_normals .n2 ,
496+                     points_normals .p1_lengths ,
497+                     points_normals .p2_lengths ,
498+                 )
471499
472500    def  test_chamfer_pointcloud_object_nonormals (self ):
473501        N  =  5 
@@ -481,9 +509,9 @@ def test_chamfer_pointcloud_object_nonormals(self):
481509            ("mean" , "mean" ),
482510            ("sum" , None ),
483511            ("mean" , None ),
512+             (None , None ),
484513        ]
485-         for  (point_reduction , batch_reduction ) in  reductions :
486- 
514+         for  point_reduction , batch_reduction  in  reductions :
487515            # Reinitialize all the tensors so that the 
488516            # backward pass can be computed. 
489517            points_normals  =  TestChamfer .init_pointclouds (
@@ -508,19 +536,38 @@ def test_chamfer_pointcloud_object_nonormals(self):
508536                batch_reduction = batch_reduction ,
509537            )
510538
511-             self .assertClose (cham_cloud , cham_tensor )
512-             self ._check_gradients (
513-                 cham_tensor ,
514-                 None ,
515-                 cham_cloud ,
516-                 None ,
517-                 points_normals .cloud1 .points_list (),
518-                 points_normals .p1 ,
519-                 points_normals .cloud2 .points_list (),
520-                 points_normals .p2 ,
521-                 lengths1 = points_normals .p1_lengths ,
522-                 lengths2 = points_normals .p2_lengths ,
523-             )
539+             if  point_reduction  is  None :
540+                 cham_tensor_bidirectional  =  torch .hstack (
541+                     [cham_tensor [0 ], cham_tensor [1 ]]
542+                 )
543+                 cham_cloud_bidirectional  =  torch .hstack ([cham_cloud [0 ], cham_cloud [1 ]])
544+                 self .assertClose (cham_cloud_bidirectional , cham_tensor_bidirectional )
545+                 self ._check_gradients (
546+                     cham_tensor_bidirectional ,
547+                     None ,
548+                     cham_cloud_bidirectional ,
549+                     None ,
550+                     points_normals .cloud1 .points_list (),
551+                     points_normals .p1 ,
552+                     points_normals .cloud2 .points_list (),
553+                     points_normals .p2 ,
554+                     lengths1 = points_normals .p1_lengths ,
555+                     lengths2 = points_normals .p2_lengths ,
556+                 )
557+             else :
558+                 self .assertClose (cham_cloud , cham_tensor )
559+                 self ._check_gradients (
560+                     cham_tensor ,
561+                     None ,
562+                     cham_cloud ,
563+                     None ,
564+                     points_normals .cloud1 .points_list (),
565+                     points_normals .p1 ,
566+                     points_normals .cloud2 .points_list (),
567+                     points_normals .p2 ,
568+                     lengths1 = points_normals .p1_lengths ,
569+                     lengths2 = points_normals .p2_lengths ,
570+                 )
524571
525572    def  test_chamfer_point_reduction_mean (self ):
526573        """ 
@@ -707,6 +754,99 @@ def test_single_directional_chamfer_point_reduction_sum(self):
707754            loss , loss_norm , pred_loss_sum , pred_loss_norm_sum , p1 , p11 , p2 , p22 
708755        )
709756
757+     def  test_chamfer_point_reduction_none (self ):
758+         """ 
759+         Compare output of vectorized chamfer loss with naive implementation 
760+         for point_reduction = None and batch_reduction = None. 
761+         """ 
762+         N , max_P1 , max_P2  =  7 , 10 , 18 
763+         device  =  get_random_cuda_device ()
764+         points_normals  =  TestChamfer .init_pointclouds (N , max_P1 , max_P2 , device )
765+         p1  =  points_normals .p1 
766+         p2  =  points_normals .p2 
767+         p1_normals  =  points_normals .n1 
768+         p2_normals  =  points_normals .n2 
769+         p11  =  p1 .detach ().clone ()
770+         p22  =  p2 .detach ().clone ()
771+         p11 .requires_grad  =  True 
772+         p22 .requires_grad  =  True 
773+ 
774+         pred_loss , pred_loss_norm  =  TestChamfer .chamfer_distance_naive (
775+             p1 , p2 , x_normals = p1_normals , y_normals = p2_normals 
776+         )
777+ 
778+         # point_reduction = None 
779+         loss , loss_norm  =  chamfer_distance (
780+             p11 ,
781+             p22 ,
782+             x_normals = p1_normals ,
783+             y_normals = p2_normals ,
784+             batch_reduction = None ,
785+             point_reduction = None ,
786+         )
787+ 
788+         loss_bidirectional  =  torch .hstack ([loss [0 ], loss [1 ]])
789+         pred_loss_bidirectional  =  torch .hstack ([pred_loss [0 ], pred_loss [1 ]])
790+         loss_norm_bidirectional  =  torch .hstack ([loss_norm [0 ], loss_norm [1 ]])
791+         pred_loss_norm_bidirectional  =  torch .hstack (
792+             [pred_loss_norm [0 ], pred_loss_norm [1 ]]
793+         )
794+ 
795+         self .assertClose (loss_bidirectional , pred_loss_bidirectional )
796+         self .assertClose (loss_norm_bidirectional , pred_loss_norm_bidirectional )
797+ 
798+         # Check gradients 
799+         self ._check_gradients (
800+             loss_bidirectional ,
801+             loss_norm_bidirectional ,
802+             pred_loss_bidirectional ,
803+             pred_loss_norm_bidirectional ,
804+             p1 ,
805+             p11 ,
806+             p2 ,
807+             p22 ,
808+         )
809+ 
810+     def  test_single_direction_chamfer_point_reduction_none (self ):
811+         """ 
812+         Compare output of vectorized chamfer loss with naive implementation 
813+         for point_reduction = None and batch_reduction = None. 
814+         """ 
815+         N , max_P1 , max_P2  =  7 , 10 , 18 
816+         device  =  get_random_cuda_device ()
817+         points_normals  =  TestChamfer .init_pointclouds (N , max_P1 , max_P2 , device )
818+         p1  =  points_normals .p1 
819+         p2  =  points_normals .p2 
820+         p1_normals  =  points_normals .n1 
821+         p2_normals  =  points_normals .n2 
822+         p11  =  p1 .detach ().clone ()
823+         p22  =  p2 .detach ().clone ()
824+         p11 .requires_grad  =  True 
825+         p22 .requires_grad  =  True 
826+ 
827+         pred_loss , pred_loss_norm  =  TestChamfer .chamfer_distance_naive (
828+             p1 , p2 , x_normals = p1_normals , y_normals = p2_normals 
829+         )
830+ 
831+         # point_reduction = None 
832+         loss , loss_norm  =  chamfer_distance (
833+             p11 ,
834+             p22 ,
835+             x_normals = p1_normals ,
836+             y_normals = p2_normals ,
837+             batch_reduction = None ,
838+             point_reduction = None ,
839+             single_directional = True ,
840+         )
841+ 
842+         self .assertClose (loss , pred_loss [0 ])
843+         self .assertClose (loss_norm , pred_loss_norm [0 ])
844+ 
845+         # Check gradients 
846+         self ._check_gradients (
847+             loss , loss_norm , pred_loss [0 ], pred_loss_norm [0 ], p1 , p11 , p2 , p22 
848+         )
849+ 
710850    def  _check_gradients (
711851        self ,
712852        loss ,
@@ -880,9 +1020,9 @@ def test_chamfer_joint_reduction(self):
8801020        with  self .assertRaisesRegex (ValueError , "batch_reduction must be one of" ):
8811021            chamfer_distance (p1 , p2 , weights = weights , batch_reduction = "max" )
8821022
883-         # Error when point_reduction is not in ["mean", "sum"]. 
1023+         # Error when point_reduction is not in ["mean", "sum"] or None . 
8841024        with  self .assertRaisesRegex (ValueError , "point_reduction must be one of" ):
885-             chamfer_distance (p1 , p2 , weights = weights , point_reduction = None )
1025+             chamfer_distance (p1 , p2 , weights = weights , point_reduction = "max" )
8861026
8871027    def  test_incorrect_weights (self ):
8881028        N , P1 , P2  =  16 , 64 , 128 
0 commit comments