@@ -28,6 +28,17 @@ def _aqt_is_int4(aqt):
2828 )
2929
3030
31+ def _same_metadata (self : "Int4PackedTensorImpl" , src : "Int4PackedTensorImpl" ) -> bool :
32+ return (
33+ isinstance (self , Int4PackedTensorImpl )
34+ and isinstance (src , Int4PackedTensorImpl )
35+ and self .shape == src .shape
36+ and self .int_data .shape == src .int_data .shape
37+ and self .scale .shape == src .scale .shape
38+ and type (self ._layout ) == type (src ._layout )
39+ )
40+
41+
3142@dataclass (frozen = True )
3243class CutlassInt4PackedLayout (Layout ):
3344 """Layout class for int4 packed layout for affine quantized tensor, for cutlass kernel."""
@@ -77,6 +88,18 @@ def __torch_dispatch__(cls, func, types, args, kwargs):
7788 func , args , kwargs , args [0 ]._apply_fn_to_data (torch .detach )
7889 )
7990
91+ elif func is aten .copy_ .default :
92+ self = args [0 ]
93+ src = args [1 ]
94+ if _same_metadata (self , src ):
95+ self_tensors = self .__tensor_flatten__ ()[0 ]
96+ for tensor_name in self_tensors :
97+ getattr (self , tensor_name ).copy_ (getattr (src , tensor_name ))
98+ return
99+ raise ValueError (
100+ f"Not supported args for copy_ due to metadata mistach: { args [0 ], args [1 ]} "
101+ )
102+
80103 raise NotImplementedError (
81104 f"Int4PackedTensorImpl dispatch: attempting to run { func } , this is not supported"
82105 )
0 commit comments