1414import sys
1515import warnings
1616from collections import namedtuple
17+ from dataclasses import asdict , dataclass
1718from io import BytesIO , TextIOBase
18- from typing import List , Optional , Tuple , cast
19+ from typing import List , Optional , Tuple
1920
2021import numpy as np
2122import torch
@@ -137,6 +138,7 @@ def __init__(self, f) -> None:
137138 self.ascii: (bool) Whether in ascii format
138139 self.big_endian: (bool) (if not ascii) whether big endian
139140 self.obj_info: (List[str]) arbitrary extra data
141+ self.comments: (List[str]) comments
140142
141143 Args:
142144 f: file-like object.
@@ -145,7 +147,8 @@ def __init__(self, f) -> None:
145147 raise ValueError ("Invalid file header." )
146148 seen_format = False
147149 self .elements : List [_PlyElementType ] = []
148- self .obj_info = []
150+ self .comments : List [str ] = []
151+ self .obj_info : List [str ] = []
149152 while True :
150153 line = f .readline ()
151154 if isinstance (line , bytes ):
@@ -176,6 +179,9 @@ def __init__(self, f) -> None:
176179 continue
177180 if line .startswith ("format" ):
178181 raise ValueError ("Invalid format line." )
182+ if line .startswith ("comment " ):
183+ self .comments .append (line [8 :])
184+ continue
179185 if line .startswith ("comment" ) or len (line ) == 0 :
180186 continue
181187 if line .startswith ("element" ):
@@ -781,9 +787,28 @@ def _load_ply_raw(f, path_manager: PathManager) -> Tuple[_PlyHeader, dict]:
781787 return header , elements
782788
783789
790+ @dataclass (frozen = True )
791+ class _VertsColumnIndices :
792+ """
793+ Contains the relevant layout of the verts section of file being read.
794+ Members
795+ point_idxs: List[int] of 3 point columns.
796+ color_idxs: List[int] of 3 color columns if they are present,
797+ otherwise None.
798+ color_scale: value to scale colors by.
799+ normal_idxs: List[int] of 3 normals columns if they are present,
800+ otherwise None.
801+ """
802+
803+ point_idxs : List [int ]
804+ color_idxs : Optional [List [int ]]
805+ color_scale : float
806+ normal_idxs : Optional [List [int ]]
807+
808+
784809def _get_verts_column_indices (
785810 vertex_head : _PlyElementType ,
786- ) -> Tuple [ List [ int ], Optional [ List [ int ]], float , Optional [ List [ int ]]] :
811+ ) -> _VertsColumnIndices :
787812 """
788813 Get the columns of verts, verts_colors, and verts_normals in the vertex
789814 element of a parsed ply file, together with a color scale factor.
@@ -809,12 +834,7 @@ def _get_verts_column_indices(
809834 vertex_head: as returned from load_ply_raw.
810835
811836 Returns:
812- point_idxs: List[int] of 3 point columns.
813- color_idxs: List[int] of 3 color columns if they are present,
814- otherwise None.
815- color_scale: value to scale colors by.
816- normal_idxs: List[int] of 3 normals columns if they are present,
817- otherwise None.
837+ _VertsColumnIndices object
818838 """
819839 point_idxs : List [Optional [int ]] = [None , None , None ]
820840 color_idxs : List [Optional [int ]] = [None , None , None ]
@@ -839,29 +859,38 @@ def _get_verts_column_indices(
839859 for idx in color_idxs
840860 ):
841861 color_scale = 1.0 / 255
842- return (
843- point_idxs ,
844- # pyre-fixme[22]: The cast is redundant.
845- None if None in color_idxs else cast (List [int ], color_idxs ),
846- color_scale ,
847- # pyre-fixme[22]: The cast is redundant.
848- None if None in normal_idxs else cast (List [int ], normal_idxs ),
862+ return _VertsColumnIndices (
863+ point_idxs = point_idxs ,
864+ color_idxs = None if None in color_idxs else color_idxs ,
865+ color_scale = color_scale ,
866+ normal_idxs = None if None in normal_idxs else normal_idxs ,
849867 )
850868
851869
852- def _get_verts (
853- header : _PlyHeader , elements : dict
854- ) -> Tuple [torch .Tensor , Optional [torch .Tensor ], Optional [torch .Tensor ]]:
870+ @dataclass (frozen = True )
871+ class _VertsData :
872+ """
873+ Contains the data of the verts section of file being read.
874+ Members:
875+ verts: FloatTensor of shape (V, 3).
876+ verts_colors: None or FloatTensor of shape (V, 3).
877+ verts_normals: None or FloatTensor of shape (V, 3).
878+ """
879+
880+ verts : torch .Tensor
881+ verts_colors : Optional [torch .Tensor ] = None
882+ verts_normals : Optional [torch .Tensor ] = None
883+
884+
885+ def _get_verts (header : _PlyHeader , elements : dict ) -> _VertsData :
855886 """
856887 Get the vertex locations, colors and normals from a parsed ply file.
857888
858889 Args:
859890 header, elements: as returned from load_ply_raw.
860891
861892 Returns:
862- verts: FloatTensor of shape (V, 3).
863- vertex_colors: None or FloatTensor of shape (V, 3).
864- vertex_normals: None or FloatTensor of shape (V, 3).
893+ _VertsData object
865894 """
866895
867896 vertex = elements .get ("vertex" , None )
@@ -870,16 +899,17 @@ def _get_verts(
870899 if not isinstance (vertex , list ):
871900 raise ValueError ("Invalid vertices in file." )
872901 vertex_head = next (head for head in header .elements if head .name == "vertex" )
873- point_idxs , color_idxs , color_scale , normal_idxs = _get_verts_column_indices (
874- vertex_head
875- )
902+
903+ column_idxs = _get_verts_column_indices (vertex_head )
876904
877905 # Case of no vertices
878906 if vertex_head .count == 0 :
879907 verts = torch .zeros ((0 , 3 ), dtype = torch .float32 )
880- if color_idxs is None :
881- return verts , None , None
882- return verts , torch .zeros ((0 , 3 ), dtype = torch .float32 ), None
908+ if column_idxs .color_idxs is None :
909+ return _VertsData (verts = verts )
910+ return _VertsData (
911+ verts = verts , verts_colors = torch .zeros ((0 , 3 ), dtype = torch .float32 )
912+ )
883913
884914 # Simple case where the only data is the vertices themselves
885915 if (
@@ -888,22 +918,22 @@ def _get_verts(
888918 and vertex [0 ].ndim == 2
889919 and vertex [0 ].shape [1 ] == 3
890920 ):
891- return _make_tensor (vertex [0 ], cols = 3 , dtype = torch .float32 ), None , None
921+ return _VertsData ( verts = _make_tensor (vertex [0 ], cols = 3 , dtype = torch .float32 ))
892922
893923 vertex_colors = None
894924 vertex_normals = None
895925
896926 if len (vertex ) == 1 :
897927 # This is the case where the whole vertex element has one type,
898928 # so it was read as a single array and we can index straight into it.
899- verts = torch .tensor (vertex [0 ][:, point_idxs ], dtype = torch .float32 )
900- if color_idxs is not None :
901- vertex_colors = color_scale * torch .tensor (
902- vertex [0 ][:, color_idxs ], dtype = torch .float32
929+ verts = torch .tensor (vertex [0 ][:, column_idxs . point_idxs ], dtype = torch .float32 )
930+ if column_idxs . color_idxs is not None :
931+ vertex_colors = column_idxs . color_scale * torch .tensor (
932+ vertex [0 ][:, column_idxs . color_idxs ], dtype = torch .float32
903933 )
904- if normal_idxs is not None :
934+ if column_idxs . normal_idxs is not None :
905935 vertex_normals = torch .tensor (
906- vertex [0 ][:, normal_idxs ], dtype = torch .float32
936+ vertex [0 ][:, column_idxs . normal_idxs ], dtype = torch .float32
907937 )
908938 else :
909939 # The vertex element is heterogeneous. It was read as several arrays,
@@ -918,7 +948,7 @@ def _get_verts(
918948 ]
919949 verts = torch .empty (size = (vertex_head .count , 3 ), dtype = torch .float32 )
920950 for axis in range (3 ):
921- partnum , col = prop_to_partnum_col [point_idxs [axis ]]
951+ partnum , col = prop_to_partnum_col [column_idxs . point_idxs [axis ]]
922952 verts .numpy ()[:, axis ] = vertex [partnum ][:, col ]
923953 # Note that in the previous line, we made the assignment
924954 # as numpy arrays by casting verts. If we took the (more
@@ -928,30 +958,49 @@ def _get_verts(
928958 # if not vertex[partnum].flags["C_CONTIGUOUS"]:
929959 # vertex[partnum] = np.ascontiguousarray(vertex[partnum])
930960 # verts[:, axis] = torch.tensor((vertex[partnum][:, col]))
931- if color_idxs is not None :
961+ if column_idxs . color_idxs is not None :
932962 vertex_colors = torch .empty (
933963 size = (vertex_head .count , 3 ), dtype = torch .float32
934964 )
935965 for color in range (3 ):
936- partnum , col = prop_to_partnum_col [color_idxs [color ]]
966+ partnum , col = prop_to_partnum_col [column_idxs . color_idxs [color ]]
937967 vertex_colors .numpy ()[:, color ] = vertex [partnum ][:, col ]
938- vertex_colors *= color_scale
939- if normal_idxs is not None :
968+ vertex_colors *= column_idxs . color_scale
969+ if column_idxs . normal_idxs is not None :
940970 vertex_normals = torch .empty (
941971 size = (vertex_head .count , 3 ), dtype = torch .float32
942972 )
943973 for axis in range (3 ):
944- partnum , col = prop_to_partnum_col [normal_idxs [axis ]]
974+ partnum , col = prop_to_partnum_col [column_idxs . normal_idxs [axis ]]
945975 vertex_normals .numpy ()[:, axis ] = vertex [partnum ][:, col ]
946976
947- return verts , vertex_colors , vertex_normals
977+ return _VertsData (
978+ verts = verts ,
979+ verts_colors = vertex_colors ,
980+ verts_normals = vertex_normals ,
981+ )
982+
983+
984+ @dataclass (frozen = True )
985+ class _PlyData :
986+ """
987+ Contains the data from a PLY file which has been read.
988+ Members:
989+ header: _PlyHeader of file metadata from the header
990+ verts: FloatTensor of shape (V, 3).
991+ faces: None or LongTensor of vertex indices, shape (F, 3).
992+ verts_colors: None or FloatTensor of shape (V, 3).
993+ verts_normals: None or FloatTensor of shape (V, 3).
994+ """
995+
996+ header : _PlyHeader
997+ verts : torch .Tensor
998+ faces : Optional [torch .Tensor ]
999+ verts_colors : Optional [torch .Tensor ]
1000+ verts_normals : Optional [torch .Tensor ]
9481001
9491002
950- def _load_ply (
951- f , * , path_manager : PathManager
952- ) -> Tuple [
953- torch .Tensor , Optional [torch .Tensor ], Optional [torch .Tensor ], Optional [torch .Tensor ]
954- ]:
1003+ def _load_ply (f , * , path_manager : PathManager ) -> _PlyData :
9551004 """
9561005 Load the data from a .ply file.
9571006
@@ -964,14 +1013,11 @@ def _load_ply(
9641013 path_manager: PathManager for loading if f is a str.
9651014
9661015 Returns:
967- verts: FloatTensor of shape (V, 3).
968- faces: None or LongTensor of vertex indices, shape (F, 3).
969- vertex_colors: None or FloatTensor of shape (V, 3).
970- vertex_normals: None or FloatTensor of shape (V, 3).
1016+ _PlyData object
9711017 """
9721018 header , elements = _load_ply_raw (f , path_manager = path_manager )
9731019
974- verts , vertex_colors , vertex_normals = _get_verts (header , elements )
1020+ verts_data = _get_verts (header , elements )
9751021
9761022 face = elements .get ("face" , None )
9771023 if face is not None :
@@ -1007,9 +1053,9 @@ def _load_ply(
10071053 faces = torch .tensor (face_list , dtype = torch .int64 )
10081054
10091055 if faces is not None :
1010- _check_faces_indices (faces , max_index = verts .shape [0 ])
1056+ _check_faces_indices (faces , max_index = verts_data . verts .shape [0 ])
10111057
1012- return verts , faces , vertex_colors , vertex_normals
1058+ return _PlyData ( ** asdict ( verts_data ) , faces = faces , header = header )
10131059
10141060
10151061def load_ply (
@@ -1064,11 +1110,12 @@ def load_ply(
10641110
10651111 if path_manager is None :
10661112 path_manager = PathManager ()
1067- verts , faces , _ , _ = _load_ply (f , path_manager = path_manager )
1113+ data = _load_ply (f , path_manager = path_manager )
1114+ faces = data .faces
10681115 if faces is None :
10691116 faces = torch .zeros (0 , 3 , dtype = torch .int64 )
10701117
1071- return verts , faces
1118+ return data . verts , faces
10721119
10731120
10741121def _write_ply_header (
@@ -1305,20 +1352,20 @@ def read(
13051352 if not endswith (path , self .known_suffixes ):
13061353 return None
13071354
1308- verts , faces , verts_colors , verts_normals = _load_ply (
1309- f = path , path_manager = path_manager
1310- )
1355+ data = _load_ply (f = path , path_manager = path_manager )
1356+ faces = data .faces
13111357 if faces is None :
13121358 faces = torch .zeros (0 , 3 , dtype = torch .int64 )
13131359
13141360 texture = None
1315- if include_textures and verts_colors is not None :
1316- texture = TexturesVertex ([verts_colors .to (device )])
1361+ if include_textures and data . verts_colors is not None :
1362+ texture = TexturesVertex ([data . verts_colors .to (device )])
13171363
1318- if verts_normals is not None :
1319- verts_normals = [verts_normals ]
1364+ verts_normals = None
1365+ if data .verts_normals is not None :
1366+ verts_normals = [data .verts_normals .to (device )]
13201367 mesh = Meshes (
1321- verts = [verts .to (device )],
1368+ verts = [data . verts .to (device )],
13221369 faces = [faces .to (device )],
13231370 textures = texture ,
13241371 verts_normals = verts_normals ,
@@ -1392,14 +1439,17 @@ def read(
13921439 if not endswith (path , self .known_suffixes ):
13931440 return None
13941441
1395- verts , faces , features , normals = _load_ply (f = path , path_manager = path_manager )
1396- verts = verts .to (device )
1397- if features is not None :
1398- features = [features .to (device )]
1399- if normals is not None :
1400- normals = [normals .to (device )]
1442+ data = _load_ply (f = path , path_manager = path_manager )
1443+ features = None
1444+ if data .verts_colors is not None :
1445+ features = [data .verts_colors .to (device )]
1446+ normals = None
1447+ if data .verts_normals is not None :
1448+ normals = [data .verts_normals .to (device )]
14011449
1402- pointcloud = Pointclouds (points = [verts ], features = features , normals = normals )
1450+ pointcloud = Pointclouds (
1451+ points = [data .verts .to (device )], features = features , normals = normals
1452+ )
14031453 return pointcloud
14041454
14051455 def save (
0 commit comments