@@ -46,16 +46,15 @@ class ScaledGroupedMMTensor(torch.Tensor):
     def __new__(
         cls,
         tensor: torch.Tensor,
-        dtype: torch.dtype,
     ):
-        logger.info(f"ScaledGroupedMMTensor __new__: tensor.dtype={tensor.dtype}, dtype: {dtype}, shape: {tensor.shape}")
+        # logger.info(f"ScaledGroupedMMTensor __new__: tensor.dtype={tensor.dtype}, dtype: {dtype}, shape: {tensor.shape}")
         return torch.Tensor._make_wrapper_subclass(
             cls,
             tensor.size(),
             strides=tensor.stride(),
             storage_offset=tensor.storage_offset(),
             memory_format=suggest_memory_format(tensor),
-            dtype=dtype,
+            dtype=tensor.dtype,
             layout=tensor.layout,
             device=tensor.device,
             pin_memory=tensor.is_pinned(),
@@ -65,15 +64,11 @@ def __new__(
     def __init__(
         self,
         tensor: torch.Tensor,
-        dtype: torch.dtype,
     ):
-        logger.info(f"ScaledGroupedMMTensor __init__: tensor.dtype={tensor.dtype}, dtype: {dtype}, shape: {tensor.shape}")
-        self._data = tensor.to(dtype)
-        self._dtype = dtype
+        self._data = tensor

     @classmethod
     def __torch_function__(cls, func, types, args, kwargs={}):
-        logger.info(f"ScaledGroupedMMTensor func: {func.__name__}, args: {args}, kwargs: {kwargs}")
         # override the grouped mm op to use the differentiable _scaled_grouped_mm
         if func.__name__ == cls.grouped_mm_func_name:
             # Use torchao scaled grouped mm with dynamic quant for
@@ -102,7 +97,7 @@ def __torch_function__(cls, func, types, args, kwargs={}):
     def __torch_dispatch__(cls, func, types, args, kwargs={}):
         # detach is special case
         if func == torch.ops.aten.detach.default:
-            return ScaledGroupedMMTensor(args[0]._data, args[0]._dtype)
+            return ScaledGroupedMMTensor(args[0]._data)

         # unwrap args/kwargs
         unwrap = lambda x: x._data if isinstance(x, ScaledGroupedMMTensor) else x
@@ -120,21 +115,20 @@ def __torch_dispatch__(cls, func, types, args, kwargs={}):
         # wrap outputs back into ScaledGroupedMMTensor for ops that do preserve subclass
         return pytree.tree_map_only(
             torch.Tensor,
-            lambda x: ScaledGroupedMMTensor(x, x.dtype),
+            lambda x: ScaledGroupedMMTensor(x),
             out,
         )

     def __repr__(self):
-        return f"ScaledGroupedMMTensor(data={self._data}, dtype={self._dtype})"
+        return f"ScaledGroupedMMTensor(data={self._data})"

     def __tensor_flatten__(self):
-        return ["_data"], {"_dtype": self._dtype}
+        return ["_data"]

     @staticmethod
     def __tensor_unflatten__(inner_tensors, flatten_spec, outer_size, outer_stride):
         return ScaledGroupedMMTensor(
             inner_tensors["_data"],
-            flatten_spec["_dtype"],
         )

     # fsdp hooks based on https://github.com/pytorch/pytorch/blob/20e40492b046b9287726d3ec656117e4dc38f0e2/test/distributed/_composable/fsdp/test_fully_shard_extensions.py#L81
@@ -146,9 +140,9 @@ def fsdp_pre_all_gather(
         module: nn.Module,
         mp_policy: MixedPrecisionPolicy,
     ):
-        all_gather_inputs = (self._data,)
+        # cast to mixed precision dtype prior to all-gather
+        all_gather_inputs = (self._data.to(mp_policy.param_dtype),)
         all_gather_metadata = ()
-        #logger.info(f"ScaledGroupedMMTensor fsdp_pre_all_gather: self._data.dtype={self._data.dtype}, self._data.shape={self._data.shape}, param_dtype: {mp_policy.param_dtype}")
         return all_gather_inputs, all_gather_metadata

     def fsdp_post_all_gather(
@@ -160,11 +154,10 @@ def fsdp_post_all_gather(
         out: Optional[torch.Tensor] = None,
     ):
         (data,) = all_gather_outputs
-        #logger.info(f"ScaledGroupedMMTensor fsdp_post_all_gather: data.dtype={data.dtype}, param_dtype: {param_dtype}")

         if out is not None:
             return

-        output = ScaledGroupedMMTensor(data, param_dtype)
+        output = ScaledGroupedMMTensor(data)
         inner_tensors = (data,)
         return output, inner_tensors
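
A minimal usage sketch (not part of the diff above) of the simplified wrapper after this change: the subclass no longer carries its own dtype, so the wrapper's dtype simply mirrors the wrapped high-precision tensor, and any cast to the mixed-precision dtype happens only inside fsdp_pre_all_gather. The import path for ScaledGroupedMMTensor is an assumption, since the file path is not visible in this excerpt.

import torch

# Assumed location of the subclass edited above; adjust to the actual module path.
from torchao.prototype.moe_training.tensor import ScaledGroupedMMTensor

# Wrap a high-precision parameter; no dtype argument is needed after this change.
param = torch.randn(128, 256, dtype=torch.bfloat16)
wrapped = ScaledGroupedMMTensor(param)

# __new__ forwards tensor.dtype, so the wrapper reports the inner tensor's dtype.
assert wrapped.dtype == torch.bfloat16

# Under FSDP2 mixed precision, fsdp_pre_all_gather casts the local shard to
# mp_policy.param_dtype just before the all-gather, while the stored _data
# keeps its original precision between iterations.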