@@ -175,8 +175,6 @@ def from_float(
         dtype = torch.int16
         scale, zero_point = choose_qparams_affine(input_float, mapping_type, block_size, dtype)
         int_data = quantize_affine(input_float, block_size, scale, zero_point, dtype)
-        # int_data = (input_float / scale).to(torch.int8)
-        print("initial:", scale.shape, " int data:", int_data.shape)
         layout_tensor_ctr = get_layout_tensor_constructor(type(layout_type))
         layout_tensor = layout_tensor_ctr(int_data, scale, layout_type)
         return cls(layout_tensor, input_float.shape)
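For reference, the retained lines follow torchao's affine quantization flow: choose_qparams_affine computes a scale and zero point per block, and quantize_affine applies them to produce the integer data, which is what makes the deleted debug print and the commented-out manual quantization redundant. Below is a minimal standalone sketch of that flow; the import path, the MappingType.SYMMETRIC choice, and the per-tensor block size are assumptions drawn from the surrounding tutorial code, not part of this commit.

# A sketch of the quantization flow used in from_float above, assuming
# torchao's quant primitives (the import path may differ across versions).
import torch
from torchao.quantization.quant_primitives import (
    MappingType,
    choose_qparams_affine,
    quantize_affine,
    dequantize_affine,
)

input_float = torch.randn(4, 8)
mapping_type = MappingType.SYMMETRIC  # assumption: symmetric mapping
block_size = input_float.shape        # assumption: per-tensor granularity
dtype = torch.int16

# one (scale, zero_point) pair per block; here, one pair for the whole tensor
scale, zero_point = choose_qparams_affine(input_float, mapping_type, block_size, dtype)
int_data = quantize_affine(input_float, block_size, scale, zero_point, dtype)

# round-trip to inspect the quantization error
reconstructed = dequantize_affine(int_data, block_size, scale, zero_point, dtype)
print((input_float - reconstructed).abs().max())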
@@ -309,6 +307,7 @@ def __torch_dispatch__(cls, func, types, args, kwargs):
                 func, args, kwargs, args[0]._apply_fn_to_data(torch.detach)
             )

+        # Tensor parallel support START
         elif func in [aten._to_copy.default, aten.clone.default]:
             return return_and_correct_aliasing(
                 func, args, kwargs, args[0]._apply_fn_to_data(torch.clone)
@@ -334,6 +333,8 @@ def __torch_dispatch__(cls, func, types, args, kwargs):
         elif func is aten.t.default:
             return return_and_correct_aliasing(func, args, kwargs, PlainMyDTypeLayout(args[0].int_data, args[0].scale, not args[0].transposed, args[0].layout_type))

+        # Tensor parallel support END
+
         raise NotImplementedError(
             f"PlainMyDTypeLayout dispatch: attempting to run {func}, this is not supported"
         )
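The START/END markers bracket the ops a wrapper subclass must dispatch for tensor parallelism: the copies DTensor makes when sharding or replicating (aten._to_copy, aten.clone) and the transpose (aten.t) that linear layers trigger. Below is a self-contained sketch of the same dispatch pattern on a toy subclass; ToyTensor and its fields are hypothetical stand-ins for PlainMyDTypeLayout, and only the dispatch structure mirrors the diff.

# Minimal sketch of the dispatch pattern bracketed by the
# "Tensor parallel support" markers. ToyTensor is a hypothetical
# stand-in for PlainMyDTypeLayout.
import torch
from torch.utils._python_dispatch import return_and_correct_aliasing

aten = torch.ops.aten

class ToyTensor(torch.Tensor):
    @staticmethod
    def __new__(cls, int_data, scale, transposed=False):
        return torch.Tensor._make_wrapper_subclass(cls, int_data.shape, dtype=scale.dtype)

    def __init__(self, int_data, scale, transposed=False):
        self.int_data = int_data
        self.scale = scale
        self.transposed = transposed

    def _apply_fn_to_data(self, fn):
        return ToyTensor(fn(self.int_data), fn(self.scale), self.transposed)

    @classmethod
    def __torch_dispatch__(cls, func, types, args, kwargs=None):
        kwargs = kwargs or {}
        if func is aten.detach.default:
            return return_and_correct_aliasing(
                func, args, kwargs, args[0]._apply_fn_to_data(torch.detach)
            )
        # copies made while sharding/replicating for tensor parallelism
        if func in [aten._to_copy.default, aten.clone.default]:
            return return_and_correct_aliasing(
                func, args, kwargs, args[0]._apply_fn_to_data(torch.clone)
            )
        # transpose is recorded as a flag instead of materializing new data
        if func is aten.t.default:
            return return_and_correct_aliasing(
                func, args, kwargs,
                ToyTensor(args[0].int_data, args[0].scale, not args[0].transposed),
            )
        raise NotImplementedError(f"ToyTensor dispatch: {func} not supported")

As a quick check, ToyTensor(torch.randint(-8, 8, (4, 4), dtype=torch.int16), torch.ones(())).t() exercises the aten.t branch and flips the transposed flag without copying the integer data, which is the same design choice the diff makes for PlainMyDTypeLayout.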