| 
5 | 5 | # LICENSE file in the root directory of this source tree.  | 
6 | 6 | 
 
  | 
7 | 7 | from collections import defaultdict  | 
8 |  | -from dataclasses import dataclass, field, fields  | 
 | 8 | +from dataclasses import dataclass  | 
9 | 9 | from typing import (  | 
10 |  | -    Any,  | 
11 | 10 |     ClassVar,  | 
12 | 11 |     Dict,  | 
13 | 12 |     Iterable,  | 
14 | 13 |     Iterator,  | 
15 | 14 |     List,  | 
16 |  | -    Mapping,  | 
17 | 15 |     Optional,  | 
18 | 16 |     Sequence,  | 
19 | 17 |     Tuple,  | 
20 | 18 |     Type,  | 
21 |  | -    Union,  | 
22 | 19 | )  | 
23 | 20 | 
 
  | 
24 |  | -import numpy as np  | 
25 | 21 | import torch  | 
26 |  | -from pytorch3d.renderer.camera_utils import join_cameras_as_batch  | 
27 |  | -from pytorch3d.renderer.cameras import CamerasBase, PerspectiveCameras  | 
28 |  | -from pytorch3d.structures.pointclouds import join_pointclouds_as_batch, Pointclouds  | 
29 | 22 | 
 
  | 
30 |  | - | 
31 |  | -@dataclass  | 
32 |  | -class FrameData(Mapping[str, Any]):  | 
33 |  | -    """  | 
34 |  | -    A type of the elements returned by indexing the dataset object.  | 
35 |  | -    It can represent both individual frames and batches thereof;  | 
36 |  | -    in this documentation, the sizes of tensors refer to single frames;  | 
37 |  | -    add the first batch dimension for the collation result.  | 
38 |  | -
  | 
39 |  | -    Args:  | 
40 |  | -        frame_number: The number of the frame within its sequence.  | 
41 |  | -            0-based continuous integers.  | 
42 |  | -        sequence_name: The unique name of the frame's sequence.  | 
43 |  | -        sequence_category: The object category of the sequence.  | 
44 |  | -        frame_timestamp: The time elapsed since the start of a sequence in sec.  | 
45 |  | -        image_size_hw: The size of the image in pixels; (height, width) tensor  | 
46 |  | -                        of shape (2,).  | 
47 |  | -        image_path: The qualified path to the loaded image (with dataset_root).  | 
48 |  | -        image_rgb: A Tensor of shape `(3, H, W)` holding the RGB image  | 
49 |  | -            of the frame; elements are floats in [0, 1].  | 
50 |  | -        mask_crop: A binary mask of shape `(1, H, W)` denoting the valid image  | 
51 |  | -            regions. Regions can be invalid (mask_crop[i,j]=0) in case they  | 
52 |  | -            are a result of zero-padding of the image after cropping around  | 
53 |  | -            the object bounding box; elements are floats in {0.0, 1.0}.  | 
54 |  | -        depth_path: The qualified path to the frame's depth map.  | 
55 |  | -        depth_map: A float Tensor of shape `(1, H, W)` holding the depth map  | 
56 |  | -            of the frame; values correspond to distances from the camera;  | 
57 |  | -            use `depth_mask` and `mask_crop` to filter for valid pixels.  | 
58 |  | -        depth_mask: A binary mask of shape `(1, H, W)` denoting pixels of the  | 
59 |  | -            depth map that are valid for evaluation, they have been checked for  | 
60 |  | -            consistency across views; elements are floats in {0.0, 1.0}.  | 
61 |  | -        mask_path: A qualified path to the foreground probability mask.  | 
62 |  | -        fg_probability: A Tensor of shape `(1, H, W)` denoting the probability of the  | 
63 |  | -            pixels belonging to the captured object; elements are floats  | 
64 |  | -            in [0, 1].  | 
65 |  | -        bbox_xywh: The bounding box tightly enclosing the foreground object in the  | 
66 |  | -            format (x0, y0, width, height). The convention assumes that  | 
67 |  | -            `x0+width` and `y0+height` include the boundary of the box.  | 
68 |  | -            I.e., to slice out the corresponding crop from an image tensor `I`  | 
69 |  | -            we execute `crop = I[..., y0:y0+height, x0:x0+width]`  | 
70 |  | -        crop_bbox_xywh: The bounding box denoting the boundaries of `image_rgb`  | 
71 |  | -            in the original image coordinates in the format (x0, y0, width, height).  | 
72 |  | -            The convention is the same as for `bbox_xywh`. `crop_bbox_xywh` differs  | 
73 |  | -            from `bbox_xywh` due to padding (which can happen e.g. due to  | 
74 |  | -            setting `JsonIndexDataset.box_crop_context > 0`)  | 
75 |  | -        camera: A PyTorch3D camera object corresponding to the frame's viewpoint,  | 
76 |  | -            corrected for cropping if it happened.  | 
77 |  | -        camera_quality_score: The score proportional to the confidence of the  | 
78 |  | -            frame's camera estimation (the higher the more accurate).  | 
79 |  | -        point_cloud_quality_score: The score proportional to the accuracy of the  | 
80 |  | -            frame's sequence point cloud (the higher the more accurate).  | 
81 |  | -        sequence_point_cloud_path: The path to the sequence's point cloud.  | 
82 |  | -        sequence_point_cloud: A PyTorch3D Pointclouds object holding the  | 
83 |  | -            point cloud corresponding to the frame's sequence. When the object  | 
84 |  | -            represents a batch of frames, point clouds may be deduplicated;  | 
85 |  | -            see `sequence_point_cloud_idx`.  | 
86 |  | -        sequence_point_cloud_idx: Integer indices mapping frame indices to the  | 
87 |  | -            corresponding point clouds in `sequence_point_cloud`; to get the  | 
88 |  | -            corresponding point cloud to `image_rgb[i]`, use  | 
89 |  | -            `sequence_point_cloud[sequence_point_cloud_idx[i]]`.  | 
90 |  | -        frame_type: The type of the loaded frame specified in  | 
91 |  | -            `subset_lists_file`, if provided.  | 
92 |  | -        meta: A dict for storing additional frame information.  | 
93 |  | -    """  | 
94 |  | - | 
95 |  | -    frame_number: Optional[torch.LongTensor]  | 
96 |  | -    sequence_name: Union[str, List[str]]  | 
97 |  | -    sequence_category: Union[str, List[str]]  | 
98 |  | -    frame_timestamp: Optional[torch.Tensor] = None  | 
99 |  | -    image_size_hw: Optional[torch.Tensor] = None  | 
100 |  | -    image_path: Union[str, List[str], None] = None  | 
101 |  | -    image_rgb: Optional[torch.Tensor] = None  | 
102 |  | -    # masks out padding added due to cropping the square bit  | 
103 |  | -    mask_crop: Optional[torch.Tensor] = None  | 
104 |  | -    depth_path: Union[str, List[str], None] = None  | 
105 |  | -    depth_map: Optional[torch.Tensor] = None  | 
106 |  | -    depth_mask: Optional[torch.Tensor] = None  | 
107 |  | -    mask_path: Union[str, List[str], None] = None  | 
108 |  | -    fg_probability: Optional[torch.Tensor] = None  | 
109 |  | -    bbox_xywh: Optional[torch.Tensor] = None  | 
110 |  | -    crop_bbox_xywh: Optional[torch.Tensor] = None  | 
111 |  | -    camera: Optional[PerspectiveCameras] = None  | 
112 |  | -    camera_quality_score: Optional[torch.Tensor] = None  | 
113 |  | -    point_cloud_quality_score: Optional[torch.Tensor] = None  | 
114 |  | -    sequence_point_cloud_path: Union[str, List[str], None] = None  | 
115 |  | -    sequence_point_cloud: Optional[Pointclouds] = None  | 
116 |  | -    sequence_point_cloud_idx: Optional[torch.Tensor] = None  | 
117 |  | -    frame_type: Union[str, List[str], None] = None  # known | unseen  | 
118 |  | -    meta: dict = field(default_factory=lambda: {})  | 
119 |  | - | 
120 |  | -    def to(self, *args, **kwargs):  | 
121 |  | -        new_params = {}  | 
122 |  | -        for f in fields(self):  | 
123 |  | -            value = getattr(self, f.name)  | 
124 |  | -            if isinstance(value, (torch.Tensor, Pointclouds, CamerasBase)):  | 
125 |  | -                new_params[f.name] = value.to(*args, **kwargs)  | 
126 |  | -            else:  | 
127 |  | -                new_params[f.name] = value  | 
128 |  | -        return type(self)(**new_params)  | 
129 |  | - | 
130 |  | -    def cpu(self):  | 
131 |  | -        return self.to(device=torch.device("cpu"))  | 
132 |  | - | 
133 |  | -    def cuda(self):  | 
134 |  | -        return self.to(device=torch.device("cuda"))  | 
135 |  | - | 
136 |  | -    # the following functions make sure **frame_data can be passed to functions  | 
137 |  | -    def __iter__(self):  | 
138 |  | -        for f in fields(self):  | 
139 |  | -            yield f.name  | 
140 |  | - | 
141 |  | -    def __getitem__(self, key):  | 
142 |  | -        return getattr(self, key)  | 
143 |  | - | 
144 |  | -    def __len__(self):  | 
145 |  | -        return len(fields(self))  | 
146 |  | - | 
147 |  | -    @classmethod  | 
148 |  | -    def collate(cls, batch):  | 
149 |  | -        """  | 
150 |  | -        Given a list of objects `batch` of class `cls`, collates them into a batched  | 
151 |  | -        representation suitable for processing with deep networks.  | 
152 |  | -        """  | 
153 |  | - | 
154 |  | -        elem = batch[0]  | 
155 |  | - | 
156 |  | -        if isinstance(elem, cls):  | 
157 |  | -            pointcloud_ids = [id(el.sequence_point_cloud) for el in batch]  | 
158 |  | -            id_to_idx = defaultdict(list)  | 
159 |  | -            for i, pc_id in enumerate(pointcloud_ids):  | 
160 |  | -                id_to_idx[pc_id].append(i)  | 
161 |  | - | 
162 |  | -            sequence_point_cloud = []  | 
163 |  | -            sequence_point_cloud_idx = -np.ones((len(batch),))  | 
164 |  | -            for i, ind in enumerate(id_to_idx.values()):  | 
165 |  | -                sequence_point_cloud_idx[ind] = i  | 
166 |  | -                sequence_point_cloud.append(batch[ind[0]].sequence_point_cloud)  | 
167 |  | -            assert (sequence_point_cloud_idx >= 0).all()  | 
168 |  | - | 
169 |  | -            override_fields = {  | 
170 |  | -                "sequence_point_cloud": sequence_point_cloud,  | 
171 |  | -                "sequence_point_cloud_idx": sequence_point_cloud_idx.tolist(),  | 
172 |  | -            }  | 
173 |  | -            # note that the pre-collate value of sequence_point_cloud_idx is unused  | 
174 |  | - | 
175 |  | -            collated = {}  | 
176 |  | -            for f in fields(elem):  | 
177 |  | -                list_values = override_fields.get(  | 
178 |  | -                    f.name, [getattr(d, f.name) for d in batch]  | 
179 |  | -                )  | 
180 |  | -                collated[f.name] = (  | 
181 |  | -                    cls.collate(list_values)  | 
182 |  | -                    if all(list_value is not None for list_value in list_values)  | 
183 |  | -                    else None  | 
184 |  | -                )  | 
185 |  | -            return cls(**collated)  | 
186 |  | - | 
187 |  | -        elif isinstance(elem, Pointclouds):  | 
188 |  | -            return join_pointclouds_as_batch(batch)  | 
189 |  | - | 
190 |  | -        elif isinstance(elem, CamerasBase):  | 
191 |  | -            # TODO: don't store K; enforce working in NDC space  | 
192 |  | -            return join_cameras_as_batch(batch)  | 
193 |  | -        else:  | 
194 |  | -            return torch.utils.data._utils.collate.default_collate(batch)  | 
195 |  | - | 
196 |  | - | 
197 |  | -class _GenericWorkaround:  | 
198 |  | -    """  | 
199 |  | -    OmegaConf.structured has a weirdness when you try to apply  | 
200 |  | -    it to a dataclass whose first base class is a Generic which is not  | 
201 |  | -    Dict. The issue is with a function called get_dict_key_value_types  | 
202 |  | -    in omegaconf/_utils.py.  | 
203 |  | -    For example this fails:  | 
204 |  | -
  | 
205 |  | -        @dataclass(eq=False)  | 
206 |  | -        class D(torch.utils.data.Dataset[int]):  | 
207 |  | -            a: int = 3  | 
208 |  | -
  | 
209 |  | -        OmegaConf.structured(D)  | 
210 |  | -
  | 
211 |  | -    We avoid the problem by adding this class as an extra base class.  | 
212 |  | -    """  | 
213 |  | - | 
214 |  | -    pass  | 
 | 23 | +from pytorch3d.implicitron.dataset.frame_data import FrameData  | 
 | 24 | +from pytorch3d.implicitron.dataset.utils import GenericWorkaround  | 
215 | 25 | 
 
  | 
216 | 26 | 
 
  | 
217 | 27 | @dataclass(eq=False)  | 
218 |  | -class DatasetBase(_GenericWorkaround, torch.utils.data.Dataset[FrameData]):  | 
 | 28 | +class DatasetBase(GenericWorkaround, torch.utils.data.Dataset[FrameData]):  | 
219 | 29 |     """  | 
220 | 30 |     Base class to describe a dataset to be used with Implicitron.  | 
221 | 31 | 
  | 
 | 
0 commit comments