From 656474679a4682203deb56e1e6299c0b1f43b84f Mon Sep 17 00:00:00 2001 From: rlyu Date: Thu, 22 Feb 2024 13:05:51 -0800 Subject: [PATCH 1/2] Optimize list_to_packed to avoid for loop --- pytorch3d/structures/utils.py | 22 ++++++++-------------- 1 file changed, 8 insertions(+), 14 deletions(-) diff --git a/pytorch3d/structures/utils.py b/pytorch3d/structures/utils.py index aab4fc3da..4d45c551d 100644 --- a/pytorch3d/structures/utils.py +++ b/pytorch3d/structures/utils.py @@ -133,22 +133,16 @@ def list_to_packed(x: List[torch.Tensor]): - **item_packed_to_list_idx**: tensor of shape sum(Mi) containing the index of the element in the list the item belongs to. """ + device = x[0].device N = len(x) - num_items = torch.zeros(N, dtype=torch.int64, device=x[0].device) - item_packed_first_idx = torch.zeros(N, dtype=torch.int64, device=x[0].device) - item_packed_to_list_idx = [] - cur = 0 - for i, y in enumerate(x): - num = len(y) - num_items[i] = num - item_packed_first_idx[i] = cur - item_packed_to_list_idx.append( - torch.full((num,), i, dtype=torch.int64, device=y.device) - ) - cur += num - + Mi = x[0].shape[0] + num_items = torch.full((N,), Mi, dtype=torch.int64).to(device) + item_packed_first_idx = torch.zeros_like(num_items) + item_packed_first_idx[1:] = torch.cumsum(num_items[:-1], dim=0) + total_items = N * Mi + item_packed_to_list_idx = torch.arange(total_items, dtype=torch.int64).to(device) + item_packed_to_list_idx = torch.bucketize(item_packed_to_list_idx, item_packed_first_idx, right=True) - 1 x_packed = torch.cat(x, dim=0) - item_packed_to_list_idx = torch.cat(item_packed_to_list_idx, dim=0) return x_packed, num_items, item_packed_first_idx, item_packed_to_list_idx From 21b9f11fefb00ea674d2d81c6dee5303b04aaa33 Mon Sep 17 00:00:00 2001 From: rlyu Date: Fri, 23 Feb 2024 17:07:23 -0800 Subject: [PATCH 2/2] Make sure it works for different lengths --- pytorch3d/structures/utils.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/pytorch3d/structures/utils.py b/pytorch3d/structures/utils.py index 4d45c551d..a2ca21c47 100644 --- a/pytorch3d/structures/utils.py +++ b/pytorch3d/structures/utils.py @@ -133,14 +133,14 @@ def list_to_packed(x: List[torch.Tensor]): - **item_packed_to_list_idx**: tensor of shape sum(Mi) containing the index of the element in the list the item belongs to. """ + if not x: + raise ValueError("Input list is empty") device = x[0].device - N = len(x) - Mi = x[0].shape[0] - num_items = torch.full((N,), Mi, dtype=torch.int64).to(device) + sizes = [xi.shape[0] for xi in x] + num_items = torch.tensor(sizes, dtype=torch.int64).to(device) item_packed_first_idx = torch.zeros_like(num_items) item_packed_first_idx[1:] = torch.cumsum(num_items[:-1], dim=0) - total_items = N * Mi - item_packed_to_list_idx = torch.arange(total_items, dtype=torch.int64).to(device) + item_packed_to_list_idx = torch.arange(torch.sum(num_items), dtype=torch.int64).to(device) item_packed_to_list_idx = torch.bucketize(item_packed_to_list_idx, item_packed_first_idx, right=True) - 1 x_packed = torch.cat(x, dim=0)