@@ -215,7 +215,8 @@ def rle_to_mask(rle: Dict[str, Any]) -> np.ndarray:
215215 return mask .transpose () # Put in C order
216216
217217
218- def _mask_to_rle_pytorch_2_0 (tensor : torch .Tensor ) -> (torch .Tensor , torch .Tensor , torch .Tensor ):
218+ @torch .compile (fullgraph = True , dynamic = True )
219+ def _mask_to_rle_pytorch_2_0_0 (tensor : torch .Tensor ) -> (torch .Tensor , torch .Tensor ):
219220 """
220221 Encodes masks to an uncompressed RLE, in the format expected by
221222 pycoco tools.
@@ -227,33 +228,53 @@ def _mask_to_rle_pytorch_2_0(tensor: torch.Tensor) -> (torch.Tensor, torch.Tenso
227228 with torch .autograd .profiler .record_function ("mask_to_rle_pytorch_2: change indices" ):
228229 # Compute change indices
229230 diff = tensor [:, 1 :] ^ tensor [:, :- 1 ]
230- a = torch .tensor ([[True ]])
231- if diff .is_cuda :
232- a = a .pin_memory ().cuda ()
233- # a = a.to(diff.device)
231+ # a = torch.tensor([[True]])
232+ a = torch .ones ((1 , 1 ), dtype = bool , device = diff .device )
233+ # if diff.is_cuda:
234+ # a = a.pin_memory().cuda()
235+ # # a = a.to(diff.device)
234236 a = a .expand_as (diff .narrow (1 , 0 , 1 ))
235237 diff = torch .cat ([a , diff , a ], dim = 1 )
236- if diff .numel () > 2147483646 :
237- num_chunks = (diff .numel () + 2147483646 ) // 2147483646
238- change_indices = torch .cat ([d .nonzero () for d in diff .chunk (num_chunks )])
239- else :
240- change_indices = diff .nonzero ()
238+ return diff
239+
240+
241+ @torch .compile (fullgraph = True , dynamic = True )
242+ def _mask_to_rle_pytorch_2_0_1 (tensor : torch .Tensor , diff : torch .Tensor , change_indices : torch .Tensor ) -> (torch .Tensor , torch .Tensor ):
243+ tensor = tensor .permute (0 , 2 , 1 ).flatten (1 )
241244
242245 with torch .autograd .profiler .record_function ("mask_to_rle_pytorch_2: all_btw_idx" ):
243246 alt_lens = diff .sum (dim = 1 )
244247
245248 all_cur_idx = change_indices [:, 1 ]
246- all_btw_idx = torch .cat ([all_cur_idx [1 :], all_cur_idx [:1 ]]) - all_cur_idx
249+ all_cur_idx_0 = all_cur_idx .narrow (0 , 1 , all_cur_idx .size (0 ) - 1 )
250+ all_cur_idx_1 = all_cur_idx .narrow (0 , 0 , 1 )
251+ all_btw_idx = torch .cat ([all_cur_idx_0 , all_cur_idx_1 ])
252+ all_btw_idx = all_btw_idx - all_cur_idx
247253
248254 with torch .autograd .profiler .record_function ("mask_to_rle_pytorch_2: Encode run length" ):
249255 alt_lens_nt = torch .nested .nested_tensor_from_jagged (all_btw_idx , lengths = alt_lens )
250256 # Encode run length
251257 counts_init = (tensor [:, 0 ] == 0 )
252- return RLEData (alt_lens_nt = alt_lens_nt ,
253- counts_init = counts_init ,
254- b = b ,
255- h = h ,
256- w = w )
258+ return alt_lens_nt , counts_init
259+
260+
def _mask_to_rle_pytorch_2_0(tensor: torch.Tensor) -> RLEData:
    """Encode a (B, H, W) boolean mask batch into intermediate RLE data.

    Orchestrates the two ``torch.compile``-d stages; the data-dependent
    ``nonzero`` runs eagerly in between, since its output size cannot be
    traced inside a ``fullgraph=True`` region.
    """
    b, h, w = tensor.shape

    with torch.autograd.profiler.record_function("mask_to_rle_pytorch_2: _mask_to_rle_pytorch_2_0_0"):
        diff = _mask_to_rle_pytorch_2_0_0(tensor)

    with torch.autograd.profiler.record_function("mask_to_rle_pytorch_2: nonzero"):
        # NOTE(review): 2147483646 is 2**31 - 2 — presumably chunking keeps
        # nonzero's element count within int32 range; confirm upstream intent.
        max_elems = 2147483646
        total = diff.numel()
        if total <= max_elems:
            change_indices = diff.nonzero()
        else:
            num_chunks = (total + max_elems) // max_elems
            pieces = [part.nonzero() for part in diff.chunk(num_chunks)]
            change_indices = torch.cat(pieces)

    with torch.autograd.profiler.record_function("mask_to_rle_pytorch_2: _mask_to_rle_pytorch_2_0_1"):
        alt_lens_nt, counts_init = _mask_to_rle_pytorch_2_0_1(tensor, diff, change_indices)

    return RLEData(alt_lens_nt=alt_lens_nt,
                   counts_init=counts_init,
                   b=b,
                   h=h,
                   w=w)
257278
258279
259280def _mask_to_rle_pytorch_2_1 (rle_data : RLEData ):
@@ -276,7 +297,8 @@ def _mask_to_rle_pytorch_2_1(rle_data: RLEData):
276297
277298
def mask_to_rle_pytorch_2(tensor: torch.Tensor) -> List[Dict[str, Any]]:
    """Encode each mask in a (B, H, W) boolean batch as an uncompressed RLE.

    Runs the two-phase pipeline under a single profiler scope so the whole
    encode shows up as one region in traces.
    """
    with torch.autograd.profiler.record_function("mask_to_rle_pytorch_2"):
        rle_data = _mask_to_rle_pytorch_2_0(tensor)
        return _mask_to_rle_pytorch_2_1(rle_data)
280302
281303
282304def area_from_rle (rle : Dict [str , Any ]) -> int :
0 commit comments