 
 import torch
 import torch.utils._pytree as pytree
+from torch import nn
 from torch._prims_common import suggest_memory_format
+from torch.distributed.device_mesh import DeviceMesh
+from torch.distributed.fsdp import MixedPrecisionPolicy
 
 from torchao.prototype.moe_training import _scaled_grouped_mm
 
 logger: logging.Logger = logging.getLogger(__name__)
 
-
 _ops_to_preserve_subclass = {
     torch.ops.aten.empty_like.default,
     torch.ops.aten.new_zeros.default,
@@ -44,15 +46,14 @@ class ScaledGroupedMMTensor(torch.Tensor):
     def __new__(
         cls,
         tensor: torch.Tensor,
-        dtype: torch.dtype,
     ):
         return torch.Tensor._make_wrapper_subclass(
             cls,
             tensor.size(),
             strides=tensor.stride(),
             storage_offset=tensor.storage_offset(),
             memory_format=suggest_memory_format(tensor),
-            dtype=dtype,
+            dtype=tensor.dtype,
             layout=tensor.layout,
             device=tensor.device,
             pin_memory=tensor.is_pinned(),
@@ -62,14 +63,11 @@ def __new__(
     def __init__(
         self,
         tensor: torch.Tensor,
-        dtype: torch.dtype,
     ):
         self._data = tensor
-        self._dtype = dtype
 
     @classmethod
     def __torch_function__(cls, func, types, args, kwargs={}):
-        logger.info(f"{func.__name__}, args: {args}, kwargs: {kwargs}")
         # override the grouped mm op to use the differentiable _scaled_grouped_mm
         if func.__name__ == cls.grouped_mm_func_name:
             # Use torchao scaled grouped mm with dynamic quant for
@@ -98,19 +96,10 @@ def __torch_function__(cls, func, types, args, kwargs={}):
     def __torch_dispatch__(cls, func, types, args, kwargs={}):
         # detach is special case
         if func == torch.ops.aten.detach.default:
-            return ScaledGroupedMMTensor(args[0]._data, args[0]._dtype)
-
-        # unwrap args and kwargs
-        dtype: Optional[torch.dtype] = None
-
-        def unwrap(t):
-            nonlocal dtype
-            if dtype is None:
-                dtype = t._dtype
-            else:
-                assert t._dtype == dtype
-            return t._data
+            return ScaledGroupedMMTensor(args[0]._data)
 
+        # unwrap args/kwargs
+        unwrap = lambda x: x._data if isinstance(x, ScaledGroupedMMTensor) else x
         args, kwargs = pytree.tree_map_only(
             ScaledGroupedMMTensor, unwrap, (args, kwargs or {})
         )
@@ -125,25 +114,33 @@ def unwrap(t):
         # wrap outputs back into ScaledGroupedMMTensor for ops that do preserve subclass
         return pytree.tree_map_only(
             torch.Tensor,
-            lambda x: ScaledGroupedMMTensor(x, dtype),
+            lambda x: ScaledGroupedMMTensor(x),
             out,
         )
 
     def __repr__(self):
-        return f"ScaledGroupedMMTensor(data={self._data}, dtype={self._dtype})"
+        return f"ScaledGroupedMMTensor(data={self._data})"
 
     def __tensor_flatten__(self):
-        return ["_data"], {"_dtype": self._dtype}
+        return ["_data"]
 
     @staticmethod
     def __tensor_unflatten__(inner_tensors, flatten_spec, outer_size, outer_stride):
         return ScaledGroupedMMTensor(
             inner_tensors["_data"],
-            flatten_spec["_dtype"],
         )
 
-    def fsdp_pre_all_gather(self, mesh):
-        all_gather_inputs = (self._data,)
+    # fsdp hooks based on https://github.com/pytorch/pytorch/blob/20e40492b046b9287726d3ec656117e4dc38f0e2/test/distributed/_composable/fsdp/test_fully_shard_extensions.py#L81
+    def fsdp_pre_all_gather(
+        self,
+        mesh: DeviceMesh,
+        outer_size: torch.Size,
+        outer_stride: tuple[int, ...],
+        module: nn.Module,
+        mp_policy: MixedPrecisionPolicy,
+    ):
+        # cast to mixed precision dtype prior to all-gather
+        all_gather_inputs = (self._data.to(mp_policy.param_dtype),)
         all_gather_metadata = ()
         return all_gather_inputs, all_gather_metadata
 
@@ -156,6 +153,25 @@ def fsdp_post_all_gather(
         out: Optional[torch.Tensor] = None,
     ):
         (data,) = all_gather_outputs
-        output = ScaledGroupedMMTensor(data, param_dtype)
+
+        # For training step 1+, out=unsharded param, so we need to copy data to `out`
+        # if `self._data` and `out` do not share the same storage.
+        # Otherwise, if they do share the same storage, we can just return directly.
+        if out is not None:
+            assert isinstance(out, ScaledGroupedMMTensor), f"{type(out)}"
+            if data.dtype == param_dtype:
+                assert (
+                    data.untyped_storage().data_ptr()
+                    == out._data.untyped_storage().data_ptr()
+                )
+            else:
+                assert out._data.dtype == param_dtype, (
+                    f"{out._data.dtype} {param_dtype}"
+                )
+                out._data.copy_(data)
+            return
+
+        # For training step 0, out=None, so we need to return a new ScaledGroupedMMTensor.
+        output = ScaledGroupedMMTensor(data)
         inner_tensors = (data,)
         return output, inner_tensors
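
For context, a rough usage sketch follows (not part of this commit). It assumes the subclass is importable from torchao.prototype.moe_training.tensor, that grouped_mm_func_name resolves to torch._grouped_mm (present in recent PyTorch nightlies), and that offs holds the cumulative end row of each expert's token group; all of these are assumptions rather than facts visible in this diff. The point is that wrapping the expert weights in ScaledGroupedMMTensor makes an ordinary grouped-mm call route through __torch_function__ into torchao's differentiable _scaled_grouped_mm with dynamic quantization.

# Hypothetical usage sketch; import path, torch._grouped_mm, and offs convention are assumptions.
import torch
from torch import nn
from torchao.prototype.moe_training.tensor import ScaledGroupedMMTensor  # assumed path

num_experts, dim, hidden, tokens_per_expert = 4, 256, 512, 8
device, dtype = "cuda", torch.bfloat16

# Wrap the expert weights so grouped mms are intercepted by __torch_function__.
w = nn.Parameter(torch.randn(num_experts, dim, hidden, dtype=dtype, device=device))
w_scaled = ScaledGroupedMMTensor(w.data)

# Routed activations for all experts, concatenated along the token dimension.
x = torch.randn(
    num_experts * tokens_per_expert, dim, dtype=dtype, device=device, requires_grad=True
)
# End offset of each expert's token group (assumed convention for `offs`).
offs = torch.arange(
    tokens_per_expert,
    num_experts * tokens_per_expert + 1,
    tokens_per_expert,
    dtype=torch.int32,
    device=device,
)

# Intercepted by ScaledGroupedMMTensor.__torch_function__ and lowered to
# torchao's _scaled_grouped_mm, which is differentiable, so backward works too.
out = torch._grouped_mm(x, w_scaled, offs=offs)
out.sum().backward()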