From 6eae22ed55dcc1da6603e1425ed3636e2c0fba12 Mon Sep 17 00:00:00 2001 From: Kyle Vedder Date: Thu, 13 Jun 2024 18:58:05 -0400 Subject: [PATCH] Fixed last dimension size check so that it doesn't trivially pass. Currently, it checks that the `2`th dimension of `p2` is the same size as the `2`th dimension of `p2` instead of `p1`. --- pytorch3d/csrc/knn/knn.cu | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch3d/csrc/knn/knn.cu b/pytorch3d/csrc/knn/knn.cu index 93a3060b2..ad9dce247 100644 --- a/pytorch3d/csrc/knn/knn.cu +++ b/pytorch3d/csrc/knn/knn.cu @@ -338,7 +338,7 @@ std::tuple KNearestNeighborIdxCuda( TORCH_CHECK((norm == 1) || (norm == 2), "Norm must be 1 or 2."); - TORCH_CHECK(p2.size(2) == D, "Point sets must have the same last dimension"); + TORCH_CHECK(p1.size(2) == D, "Point sets must have the same last dimension"); auto long_dtype = lengths1.options().dtype(at::kLong); auto idxs = at::zeros({N, P1, K}, long_dtype); auto dists = at::zeros({N, P1, K}, p1.options());