@@ -73,7 +73,7 @@ def __new__(
                 cat_tensor_shape[1] += shard.size()[1]
 
         # in cases of sharding optimizer rowwise, we calculate total tensor size by "concat" on first tensor dimension
-        if len(local_shards) > 1 and local_shards[0].ndim == 1:  # column-wise sharding
+        if len(local_shards) > 1 and local_shards[0].ndim == 1:  # row-wise sharding
             for shard in local_shards[1:]:
                 cat_tensor_shape[0] += shard.size()[0]
 
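As a quick illustration of the size bookkeeping in this hunk (a minimal sketch with made-up shard sizes, not code from this PR): 1D row-wise shards "concat" on the first dimension, while 2D column-wise shards concat on the second.

```python
import torch

rw_shards = [torch.zeros(4), torch.zeros(6)]        # hypothetical 1D row-wise shards
cw_shards = [torch.zeros(8, 3), torch.zeros(8, 5)]  # hypothetical 2D column-wise shards

# Row-wise (1D): total size grows along dim 0 -> logical size [10]
rw_size = list(rw_shards[0].size())
for shard in rw_shards[1:]:
    rw_size[0] += shard.size()[0]
assert rw_size == [10]

# Column-wise (2D): total size grows along dim 1 -> logical size [8, 8]
cw_size = list(cw_shards[0].size())
for shard in cw_shards[1:]:
    cw_size[1] += shard.size()[1]
assert cw_size == [8, 8]
```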
@@ -119,6 +119,7 @@ def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
             aten.copy_.default: cls.handle_copy_,
             aten.zeros_like.default: cls.handle_zeros_like,
             aten.empty_like.default: cls.handle_empty_like,
+            aten.constant_pad_nd.default: cls.handle_constant_pad_nd,
         }
 
         if func in dispatcher:
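For context on why a single dispatcher entry suffices: `torch.nn.functional.pad(..., mode="constant")` decomposes to `aten.constant_pad_nd.default` before it reaches `__torch_dispatch__`. A minimal sketch to observe this (assumes a recent PyTorch exposing `TorchDispatchMode`; not part of this PR):

```python
import torch
from torch.utils._python_dispatch import TorchDispatchMode

class LogOps(TorchDispatchMode):
    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        print(func)  # constant pads arrive here as aten.constant_pad_nd.default
        return func(*args, **(kwargs or {}))

x = torch.ones(2, 3)
with LogOps():
    torch.nn.functional.pad(x, [1, 0], mode="constant", value=0.0)
```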
@@ -279,6 +280,208 @@ def handle_new_empty(args, kwargs):
             self_ls.local_offsets(),
         )
 
+    @staticmethod
+    # pyre-fixme[3]: Return type must be annotated.
+    # pyre-fixme[2]: Parameter must be annotated.
+    def handle_constant_pad_nd(args, kwargs):
+        """
+        Apply constant padding to a LocalShardsWrapper.
+
+        The padding is based on the following ideas:
+        - The resulting wrapper represents the padded version of the logical tensor.
+        - Each shard is padded based on the sharding type and the dimension being padded.
+        - For instance, CW shards padded on the left-most column have padding only on the first CW shard.
+        - Padding the top row applies to all CW shards.
+        """
+        self_lsw = args[0]
+        pad_spec = args[1]
+        pad_value = args[2] if len(args) > 2 else 0.0
+
+        if len(self_lsw.local_shards()) == 0:
+            raise NotImplementedError(
+                "Padding empty LocalShardsWrapper is not supported."
+            )
+
+        local_shards = self_lsw.local_shards()
+
+        if len(local_shards) == 1:
+            padded_shard = torch.nn.functional.pad(
+                local_shards[0], pad_spec, mode="constant", value=pad_value
+            )
+            return LocalShardsWrapper([padded_shard], self_lsw.local_offsets())
+
+        padded_shards = list(local_shards)
+
+        if local_shards[0].ndim == 2:
+            # 2D column-wise sharding: [pad_left, pad_right, pad_top, pad_bottom]
+            if len(pad_spec) == 2:
+                # A length-2 spec pads only the last (column) dimension
+                pad_spec = pad_spec + [0, 0]
+
+            if len(pad_spec) != 4:
+                raise ValueError(
+                    f"Padding spec must be of length 4 for 2D tensors, got {len(pad_spec)}"
+                )
+
+            pad_left, pad_right, pad_top, pad_bottom = (
+                pad_spec[0],
+                pad_spec[1],
+                pad_spec[2],
+                pad_spec[3],
+            )
+
+            if pad_top > 0:
+                padded_shards = [
+                    torch.nn.functional.pad(
+                        shard, [0, 0, pad_top, 0], mode="constant", value=pad_value
+                    )
+                    for shard in padded_shards
+                ]
+            if pad_bottom > 0:
+                padded_shards = [
+                    torch.nn.functional.pad(
+                        shard, [0, 0, 0, pad_bottom], mode="constant", value=pad_value
+                    )
+                    for shard in padded_shards
+                ]
+            if pad_left > 0:
+                padded_shards[0] = torch.nn.functional.pad(
+                    padded_shards[0],
+                    [pad_left, 0, 0, 0],
+                    mode="constant",
+                    value=pad_value,
+                )
+            if pad_right > 0:
+                padded_shards[-1] = torch.nn.functional.pad(
+                    padded_shards[-1],
+                    [0, pad_right, 0, 0],
+                    mode="constant",
+                    value=pad_value,
+                )
+        elif local_shards[0].ndim == 1:
+            # 1D row-wise sharding: [pad_top, pad_bottom]
+            if len(pad_spec) != 2:
+                raise ValueError(
+                    f"Padding spec must be of length 2 for 1D tensors, got {len(pad_spec)}"
+                )
+            pad_top, pad_bottom = pad_spec[0], pad_spec[1]
+
+            if pad_top > 0:
+                padded_shards[0] = torch.nn.functional.pad(
+                    padded_shards[0], [pad_top, 0], mode="constant", value=pad_value
+                )
+            if pad_bottom > 0:
+                padded_shards[-1] = torch.nn.functional.pad(
+                    padded_shards[-1], [0, pad_bottom], mode="constant", value=pad_value
+                )
+        else:
+            raise NotImplementedError(
+                f"Padding for {local_shards[0].ndim}D tensors is not supported. "
+                f"Only 1D and 2D tensors are currently supported."
+            )
+
+        # Update offsets and storage metadata
+        original_storage = self_lsw.storage_metadata()
+        updated_offsets, updated_storage = LocalShardsWrapper._compute_updated_metadata(
+            original_storage,
+            self_lsw.local_offsets(),
+            pad_spec,
+            local_shards[0].ndim,
+            padded_shards,
+        )
+
+        result = LocalShardsWrapper(padded_shards, updated_offsets)
+        result._storage_meta = updated_storage
+        return result
+
+    @staticmethod
+    def _compute_updated_metadata(
+        original_storage: TensorStorageMetadata,
+        original_offsets: list[torch.Size],
+        pad_spec: list[int],
+        ndim: int,
+        padded_shards: list[torch.Tensor],
+    ) -> tuple[list[tuple[int, ...]], TensorStorageMetadata]:
+        """
+        Compute updated offsets and storage metadata after padding is applied.
+
+        Args:
+            original_storage: Original storage metadata
+            original_offsets: Original shard offsets
+            pad_spec: Padding specification
+            ndim: Number of dimensions (1 = RW, 2 = CW)
+            padded_shards: Padded shard tensors
+
+        Returns:
+            Tuple of (updated_offsets, updated_storage_metadata)
+        """
+        if ndim == 1:  # 1D RW
+            pad_top, pad_bottom = pad_spec[0], pad_spec[1]
+
+            updated_offsets = []
+            for i, offset in enumerate(original_offsets):
+                if i == 0:
+                    # First shard: offset stays the same (absorbs top padding)
+                    updated_offsets.append(tuple(offset))
+                else:
+                    # Subsequent shards: shift by the top padding amount
+                    new_offset = (offset[0] + pad_top,)
+                    updated_offsets.append(new_offset)
+
+            new_global_size = torch.Size(
+                [original_storage.size[0] + pad_top + pad_bottom]
+            )
+
+        elif ndim == 2:  # 2D CW
+            pad_left, pad_right, pad_top, pad_bottom = (
+                pad_spec[0],
+                pad_spec[1],
+                pad_spec[2],
+                pad_spec[3],
+            )
+
+            updated_offsets = []
+            for i, offset in enumerate(original_offsets):
+                row_offset = offset[0]
+                col_offset = offset[1]
+
+                # Top/bottom padding doesn't affect offsets;
+                # left padding shifts column offsets.
+                if i == 0:
+                    # First shard: column offset stays the same (absorbs left padding)
+                    new_2d_offset = (row_offset, col_offset)
+                else:
+                    # Subsequent shards: shift column offset by the left padding amount
+                    new_2d_offset = (row_offset, col_offset + pad_left)
+
+                updated_offsets.append(new_2d_offset)
+
+            new_global_size = torch.Size(
+                [
+                    original_storage.size[0] + pad_top + pad_bottom,
+                    original_storage.size[1] + pad_left + pad_right,
+                ]
+            )
+
+        else:
+            raise NotImplementedError(f"Metadata computation for {ndim}D not supported")
+
+        updated_chunks = [
+            ChunkStorageMetadata(
+                offsets=torch.Size(offset),
+                sizes=shard.size(),
+            )
+            for offset, shard in zip(updated_offsets, padded_shards)
+        ]
+
+        updated_storage = TensorStorageMetadata(
+            properties=original_storage.properties,
+            size=new_global_size,
+            chunks=updated_chunks,
+        )
+
+        return updated_offsets, updated_storage
+
     @property
     def device(self) -> torch._C.device:  # type: ignore[override]
         return (
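Putting the new handler and `_compute_updated_metadata` together, here is a hedged end-to-end sketch of the intended semantics, using plain tensors in place of a `LocalShardsWrapper` and made-up sizes: two column-wise shards of a logical 2x6 tensor, padded with `pad_spec = [1, 2, 1, 0]`.

```python
import torch
import torch.nn.functional as F

# Hypothetical CW shards covering columns [0, 3) and [3, 6) of a 2x6 tensor.
shards = [torch.ones(2, 3), torch.full((2, 3), 2.0)]
offsets = [(0, 0), (0, 3)]
pad_left, pad_right, pad_top, pad_bottom = 1, 2, 1, 0

# Top padding applies to every shard; left/right only to the edge shards.
padded = [F.pad(s, [0, 0, pad_top, 0], value=0.0) for s in shards]
padded[0] = F.pad(padded[0], [pad_left, 0, 0, 0], value=0.0)
padded[-1] = F.pad(padded[-1], [0, pad_right, 0, 0], value=0.0)

# First shard absorbs the left padding; later shards shift right by pad_left.
new_offsets = [offsets[0], (offsets[1][0], offsets[1][1] + pad_left)]

assert padded[0].shape == (3, 4) and padded[1].shape == (3, 5)
assert new_offsets == [(0, 0), (0, 4)]
# New logical size: (2 + 1 + 0) x (6 + 1 + 2) = 3 x 9, matching the chunk metadata.
```

This mirrors the design choice in `_compute_updated_metadata`: only the first shard's offset stays fixed, so the padded wrapper still describes one contiguous logical tensor.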