1313#include < thrust/scan.h>
1414#include < cstdio>
1515#include " marching_cubes/tables.h"
16- #include " utils/pytorch3d_cutils.h"
1716
1817/*
1918Parallelized marching cubes for pytorch extension
@@ -267,13 +266,12 @@ __global__ void CompactVoxelsKernel(
267266// isolevel: threshold to determine isosurface intersection
268267//
269268__global__ void GenerateFacesKernel (
270- torch ::PackedTensorAccessor32<float , 2 , torch ::RestrictPtrTraits> verts,
271- torch ::PackedTensorAccessor<int64_t , 2 , torch ::RestrictPtrTraits> faces,
272- torch ::PackedTensorAccessor<int64_t , 1 , torch ::RestrictPtrTraits> ids,
273- torch ::PackedTensorAccessor32<int , 1 , torch ::RestrictPtrTraits>
269+ at ::PackedTensorAccessor32<float , 2 , at ::RestrictPtrTraits> verts,
270+ at ::PackedTensorAccessor<int64_t , 2 , at ::RestrictPtrTraits> faces,
271+ at ::PackedTensorAccessor<int64_t , 1 , at ::RestrictPtrTraits> ids,
272+ at ::PackedTensorAccessor32<int , 1 , at ::RestrictPtrTraits>
274273 compactedVoxelArray,
275- torch::PackedTensorAccessor32<int , 1 , torch::RestrictPtrTraits>
276- numVertsScanned,
274+ at::PackedTensorAccessor32<int , 1 , at::RestrictPtrTraits> numVertsScanned,
277275 const uint activeVoxels,
278276 const at::PackedTensorAccessor32<float , 3 , at::RestrictPtrTraits> vol,
279277 const at::PackedTensorAccessor32<int , 2 , at::RestrictPtrTraits> faceTable,
@@ -436,15 +434,15 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> MarchingCubesCuda(
436434 cudaStream_t stream = at::cuda::getCurrentCUDAStream ();
437435
438436 // transfer _FACE_TABLE data to device
439- torch ::Tensor face_table_tensor = torch ::zeros (
440- {256 , 16 }, torch ::TensorOptions ().dtype (at::kInt ).device (at::kCPU ));
437+ at ::Tensor face_table_tensor = at ::zeros (
438+ {256 , 16 }, at ::TensorOptions ().dtype (at::kInt ).device (at::kCPU ));
441439 auto face_table_a = face_table_tensor.accessor <int , 2 >();
442440 for (int i = 0 ; i < 256 ; i++) {
443441 for (int j = 0 ; j < 16 ; j++) {
444442 face_table_a[i][j] = _FACE_TABLE[i][j];
445443 }
446444 }
447- torch ::Tensor faceTable = face_table_tensor.to (vol.device ());
445+ at ::Tensor faceTable = face_table_tensor.to (vol.device ());
448446
449447 // get numVoxels
450448 int threads = 128 ;
@@ -458,10 +456,10 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> MarchingCubesCuda(
458456 }
459457
460458 auto d_voxelVerts =
461- torch ::zeros ({numVoxels}, torch ::TensorOptions ().dtype (at::kInt ))
459+ at ::zeros ({numVoxels}, at ::TensorOptions ().dtype (at::kInt ))
462460 .to (vol.device ());
463461 auto d_voxelOccupied =
464- torch ::zeros ({numVoxels}, torch ::TensorOptions ().dtype (at::kInt ))
462+ at ::zeros ({numVoxels}, at ::TensorOptions ().dtype (at::kInt ))
465463 .to (vol.device ());
466464
467465 // Execute "ClassifyVoxelKernel" kernel to precompute
@@ -480,7 +478,7 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> MarchingCubesCuda(
480478 // If the number of active voxels is 0, return zero tensor for verts and
481479 // faces.
482480 auto d_voxelOccupiedScan =
483- torch ::zeros ({numVoxels}, torch ::TensorOptions ().dtype (at::kInt ))
481+ at ::zeros ({numVoxels}, at ::TensorOptions ().dtype (at::kInt ))
484482 .to (vol.device ());
485483 ThrustScanWrapper (
486484 d_voxelOccupiedScan.data_ptr <int >(),
@@ -493,23 +491,21 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> MarchingCubesCuda(
493491 int activeVoxels = lastElement + lastScan;
494492
495493 const int device_id = vol.device ().index ();
496- auto opt =
497- torch::TensorOptions ().dtype (torch::kInt ).device (torch::kCUDA , device_id);
498- auto opt_long = torch::TensorOptions ()
499- .dtype (torch::kInt64 )
500- .device (torch::kCUDA , device_id);
494+ auto opt = at::TensorOptions ().dtype (at::kInt ).device (at::kCUDA , device_id);
495+ auto opt_long =
496+ at::TensorOptions ().dtype (at::kLong ).device (at::kCUDA , device_id);
501497
502498 if (activeVoxels == 0 ) {
503499 int ntris = 0 ;
504- torch ::Tensor verts = torch ::zeros ({ntris * 3 , 3 }, vol.options ());
505- torch ::Tensor faces = torch ::zeros ({ntris, 3 }, opt_long);
506- torch ::Tensor ids = torch ::zeros ({ntris}, opt_long);
500+ at ::Tensor verts = at ::zeros ({ntris * 3 , 3 }, vol.options ());
501+ at ::Tensor faces = at ::zeros ({ntris, 3 }, opt_long);
502+ at ::Tensor ids = at ::zeros ({ntris}, opt_long);
507503 return std::make_tuple (verts, faces, ids);
508504 }
509505
510506 // Execute "CompactVoxelsKernel" kernel to compress voxels for accleration.
511507 // This allows us to run triangle generation on only the occupied voxels.
512- auto d_compVoxelArray = torch ::zeros ({activeVoxels}, opt);
508+ auto d_compVoxelArray = at ::zeros ({activeVoxels}, opt);
513509 CompactVoxelsKernel<<<grid, threads, 0 , stream>>> (
514510 d_compVoxelArray.packed_accessor32 <int , 1 , at::RestrictPtrTraits>(),
515511 d_voxelOccupied.packed_accessor32 <int , 1 , at::RestrictPtrTraits>(),
@@ -519,7 +515,7 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> MarchingCubesCuda(
519515 cudaDeviceSynchronize ();
520516
521517 // Scan d_voxelVerts array to generate offsets of vertices for each voxel
522- auto d_voxelVertsScan = torch ::zeros ({numVoxels}, opt);
518+ auto d_voxelVertsScan = at ::zeros ({numVoxels}, opt);
523519 ThrustScanWrapper (
524520 d_voxelVertsScan.data_ptr <int >(),
525521 d_voxelVerts.data_ptr <int >(),
@@ -533,10 +529,10 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> MarchingCubesCuda(
533529 // Execute "GenerateFacesKernel" kernel
534530 // This runs only on the occupied voxels.
535531 // It looks up the field values and generates the triangle data.
536- torch ::Tensor verts = torch ::zeros ({totalVerts, 3 }, vol.options ());
537- torch ::Tensor faces = torch ::zeros ({totalVerts / 3 , 3 }, opt_long);
532+ at ::Tensor verts = at ::zeros ({totalVerts, 3 }, vol.options ());
533+ at ::Tensor faces = at ::zeros ({totalVerts / 3 , 3 }, opt_long);
538534
539- torch ::Tensor ids = torch ::zeros ({totalVerts}, opt_long);
535+ at ::Tensor ids = at ::zeros ({totalVerts}, opt_long);
540536
541537 dim3 grid2 ((activeVoxels + threads - 1 ) / threads, 1 , 1 );
542538 if (grid2.x > 65535 ) {
0 commit comments