torchao/prototype/moe_training: 2 files changed, +16 -2 lines.

First changed file:

@@ -40,7 +40,7 @@ def _scaled_grouped_mm(
         offs (int32 torch.Tensor): The offsets to use to mark the starting index of each group along dim0 of the A tensor.
         out_dtype (Optional[torch.dtype]): The dtype of the output tensor. Currently only torch.bfloat16 is supported.
     """
-    logger.info("Using scaled_grouped_mm")
+    # logger.info("Using scaled_grouped_mm")
     return _Float8GroupedMM.apply(
         A,
         B_t,
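The removed `logger.info` fired on every call into `_scaled_grouped_mm`, i.e. on every forward pass. If the message is still wanted, one alternative (not part of this diff; the flag and helper names below are hypothetical) is to emit it only once per process:

```python
import logging

logger = logging.getLogger(__name__)

# Hypothetical module-level flag; not part of the PR.
_logged_scaled_grouped_mm_use = False


def _log_scaled_grouped_mm_once() -> None:
    """Log the 'Using scaled_grouped_mm' message only the first time the op is hit."""
    global _logged_scaled_grouped_mm_use
    if not _logged_scaled_grouped_mm_use:
        logger.info("Using scaled_grouped_mm")
        _logged_scaled_grouped_mm_use = True
```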
Second changed file:

@@ -47,7 +47,6 @@ def __new__(
         cls,
         tensor: torch.Tensor,
     ):
-        # logger.info(f"ScaledGroupedMMTensor __new__: tensor.dtype={tensor.dtype}, dtype: {dtype}, shape: {tensor.shape}")
         return torch.Tensor._make_wrapper_subclass(
             cls,
             tensor.size(),
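For context, `torch.Tensor._make_wrapper_subclass` (the private hook called above) builds a tensor subclass that reports the inner tensor's metadata while keeping the real data in a wrapped attribute. A minimal, self-contained sketch of that pattern, independent of `ScaledGroupedMMTensor` (the class name below is made up for illustration):

```python
import torch
from torch.utils._pytree import tree_map_only


class WrapperTensorSketch(torch.Tensor):
    """Toy wrapper subclass: stores the real data in `_data` and unwraps on dispatch."""

    @staticmethod
    def __new__(cls, tensor: torch.Tensor):
        # The wrapper mirrors the inner tensor's metadata but holds no storage itself.
        return torch.Tensor._make_wrapper_subclass(
            cls,
            tensor.size(),
            strides=tensor.stride(),
            dtype=tensor.dtype,
            device=tensor.device,
            requires_grad=tensor.requires_grad,
        )

    def __init__(self, tensor: torch.Tensor):
        self._data = tensor

    def __repr__(self):
        return f"WrapperTensorSketch({self._data!r})"

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        # Unwrap any WrapperTensorSketch arguments and run the op on plain tensors.
        def unwrap(t):
            return t._data

        args = tree_map_only(WrapperTensorSketch, unwrap, args)
        kwargs = tree_map_only(WrapperTensorSketch, unwrap, kwargs or {})
        return func(*args, **kwargs)


x = WrapperTensorSketch(torch.randn(4, 4))
y = torch.mm(x, torch.randn(4, 4))  # routes through __torch_dispatch__
```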
@@ -155,9 +154,24 @@ def fsdp_post_all_gather(
     ):
         (data,) = all_gather_outputs

+        # For training step 1+, out=unsharded param, so we need to copy data to `out`
+        # if `self._data` and `out` do not share the same storage.
+        # Otherwise, if they do share the same storage, we can just return directly.
         if out is not None:
+            assert isinstance(out, ScaledGroupedMMTensor), f"{type(out)}"
+            if data.dtype == param_dtype:
+                assert (
+                    data.untyped_storage().data_ptr()
+                    == out._data.untyped_storage().data_ptr()
+                )
+            else:
+                assert out._data.dtype == param_dtype, (
+                    f"{out._data.dtype} {param_dtype}"
+                )
+                out._data.copy_(data)
             return

+        # For training step 0, out=None, so we need to return a new ScaledGroupedMMTensor.
         output = ScaledGroupedMMTensor(data)
         inner_tensors = (data,)
         return output, inner_tensors
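The new branch relies on FSDP2 handing back an `out` parameter whose storage may or may not already alias the all-gather output buffer. A minimal sketch of just that storage check on plain tensors (the helper name is made up; this is not the FSDP code path itself):

```python
import torch


def copy_if_needed(out: torch.Tensor, data: torch.Tensor) -> None:
    """Copy `data` into `out` only when the two do not already share storage."""
    if data.untyped_storage().data_ptr() == out.untyped_storage().data_ptr():
        # Same underlying buffer: `out` already sees the all-gathered values.
        return
    # Different buffers: materialize the values in `out` (copy_ also casts dtypes).
    out.copy_(data)


base = torch.randn(4, 4)
alias = base.view(4, 4)       # shares storage with `base`
copy_if_needed(alias, base)   # no-op

dst = torch.empty(4, 4, dtype=torch.bfloat16)
copy_if_needed(dst, base)     # copies and casts into bf16
```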