@@ -1,24 +1,26 @@
+from typing import Tuple
+
 import torch
-from typing import Tuple, Dict, List
 
 import torchao.sparsity.marlin.utils as utils
 from torchao.sparsity.marlin.utils import const
 from torchao.sparsity.utils import mask_creator
 
-
 __all__ = [
     "inject_24",
     "marlin_24_workspace",
     "pack_to_marlin_24",
     "unpack_from_marlin_24",
 ]
 
 
-def inject_24(w: torch.Tensor, size_k: int, size_n: int) -> Tuple[torch.Tensor, torch.Tensor]:
+def inject_24(
+    w: torch.Tensor, size_k: int, size_n: int
+) -> Tuple[torch.Tensor, torch.Tensor]:
     """Injects 2:4 sparsity into a weight tensor. The sparsity is applied in a 2:4 ratio, where for every
     group of 4 weights, 2 will be pruned based on their value. The mask will be created based on the
     ranked weight values.
-    
+
     Args:
         w (torch.Tensor): The weight tensor to inject sparsity into.
         size_k (int): The number of input features.
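
Reviewer note: the pruning semantics described in the docstring are easiest to see at a call site. A minimal sketch (the shapes, dtype, and device are illustrative assumptions, not part of this diff):

    w = torch.randn(128, 256, dtype=torch.half, device="cuda")
    size_k, size_n = w.shape  # input features, output features
    # Keeps the 2 largest-magnitude weights in each group of 4 and zeros the
    # rest; returns the masked weights and the mask, per the Tuple annotation.
    w_24, mask = inject_24(w, size_k, size_n)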
@@ -32,33 +34,35 @@ def inject_24(w: torch.Tensor, size_k: int, size_n: int) -> Tuple[torch.Tensor,
 
 
 def marlin_24_workspace(
-        out_features: int,
-        min_thread_n: int = const.MIN_THREAD_N,
-        max_parallel: int = const.MAX_PARALLEL
-    ) -> torch.Tensor:
-    """Creates a workspace for marlin 2:4 quantization. The workspace is used to coordinate the locks 
+    out_features: int,
+    min_thread_n: int = const.MIN_THREAD_N,
+    max_parallel: int = const.MAX_PARALLEL,
+) -> torch.Tensor:
+    """Creates a workspace for marlin 2:4 quantization. The workspace is used to coordinate the locks
     during the execution of the kernel.
-    
+
     Args:
         out_features (int): The number of output features.
         min_thread_n (int, optional): The minimum number of threads per block. Defaults to `MARLIN_24_MIN_THREAD_N`.
         max_parallel (int, optional): The maximum number of parallel threads. Defaults to `MARLIN_24_MAX_PARALLEL`.
     Returns:
         torch.Tensor: The workspace tensor fully initialized with zeros.
     """
-    assert (out_features % min_thread_n == 0), f"out_features = {out_features}, min_thread_n = {min_thread_n}"
-    max_workspace_size = ((out_features // min_thread_n) * max_parallel)
+    assert (
+        out_features % min_thread_n == 0
+    ), f"out_features = {out_features}, min_thread_n = {min_thread_n}"
+    max_workspace_size = (out_features // min_thread_n) * max_parallel
     return torch.zeros(max_workspace_size, dtype=torch.int, device="cuda")
 
 
 def pack_to_marlin_24(
-        q_w_24: torch.Tensor,
-        scales: torch.Tensor,
-        num_bits: int,
-        group_size: int,
-    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+    q_w_24: torch.Tensor,
+    scales: torch.Tensor,
+    num_bits: int,
+    group_size: int,
+) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
     """Packs the quantized weights and scales into the marlin 2:4 format.
-    
+
     Args:
         q_w_24 (torch.Tensor): The quantized weight tensor with 2:4 sparsity applied.
         scales (torch.Tensor): The scale tensor.
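
The workspace contract is fully visible above, so a call-site sketch may help review (4096 is an arbitrary example width; the assert requires it to be divisible by `const.MIN_THREAD_N`):

    # One zero-initialized int32 lock slot per output-column block and
    # parallel slice, allocated on the CUDA device.
    workspace = marlin_24_workspace(out_features=4096)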
@@ -89,13 +93,13 @@ def pack_to_marlin_24(
 
 
 def unpack_from_marlin_24(
-        q_w_24_comp: torch.Tensor,
-        scales: torch.Tensor,
-        meta: torch.Tensor,
-        original_shape: torch.Size,
-        group_size: int,
-        num_bits: int
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
+    q_w_24_comp: torch.Tensor,
+    scales: torch.Tensor,
+    meta: torch.Tensor,
+    original_shape: torch.Size,
+    group_size: int,
+    num_bits: int,
+) -> Tuple[torch.Tensor, torch.Tensor]:
     """Unpacks the quantized weights and scales from the marlin 2:4 format.
     Args:
         q_w_24_comp (torch.Tensor): The packed quantized weights.
@@ -109,10 +113,8 @@ def unpack_from_marlin_24(
109113 """
110114 in_features , out_features = original_shape
111115
112- # Unpacks the scales
113- unpacked_scales = _from_marlin_scale (
114- scales , * original_shape , group_size , num_bits
115- )
116+ # Unpacks the scales
117+ unpacked_scales = _from_marlin_scale (scales , * original_shape , group_size , num_bits )
116118
117119 in_features_comp = in_features // 2
118120
@@ -130,14 +132,11 @@ def unpack_from_marlin_24(
 
 
 def _compress_quantized_24_weight(
-        q_24: torch.Tensor,
-        size_k: int,
-        size_n: int,
-        num_bits: int
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
-    """Compresses the quantized weights to a 2:4 sparse format. Normalizes the weights over 0 
+    q_24: torch.Tensor, size_k: int, size_n: int, num_bits: int
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    """Compresses the quantized weights to a 2:4 sparse format. Normalizes the weights over 0
     before compressing them.
-    
+
     Args:
         q_24 (torch.Tensor): The quantized weight tensor.
         size_k (int): The number of input features.
@@ -168,14 +167,10 @@ def _compress_quantized_24_weight(
 
 
 def _decompress_quantized_24_weight(
-        q_24_comp: torch.Tensor,
-        meta: torch.Tensor,
-        size_k: int,
-        size_n: int,
-        num_bits: int
-    ) -> torch.Tensor:
+    q_24_comp: torch.Tensor, meta: torch.Tensor, size_k: int, size_n: int, num_bits: int
+) -> torch.Tensor:
     """Decompresses the quantized weights from a 2:4 sparse format and restores the original shape.
-    
+
     Args:
         q_24_comp (torch.Tensor): The compressed quantized weight tensor in 2:4 sparse format.
         meta (torch.Tensor): The meta tensor.
@@ -210,13 +205,13 @@ def _decompress_quantized_24_weight(
 
 
 def _to_marlin_weights(
-        q_w: torch.Tensor,
-        size_k: int,
-        size_n: int,
-        num_bits: int,
-    ) -> torch.Tensor:
+    q_w: torch.Tensor,
+    size_k: int,
+    size_n: int,
+    num_bits: int,
+) -> torch.Tensor:
     """Converts a quantized and 2:4 sparse format weight tensor to the marlin 2:4 format.
-    
+
     Args:
         q_w (torch.Tensor): The quantized weight tensor in 2:4 sparse format.
         size_k (int): The number of input features.
@@ -236,7 +231,11 @@ def _to_marlin_weights(
     # Original implementation uses numpy + uint32 but we need to use int64 because torch.uint32
     # does not support rshift_cpu.
     q_w = q_w.cpu().to(torch.int64)
-    q_packed = torch.zeros((q_w.shape[0], q_w.shape[1] // pack_factor), dtype=torch.int64, device=q_w.device)
+    q_packed = torch.zeros(
+        (q_w.shape[0], q_w.shape[1] // pack_factor),
+        dtype=torch.int64,
+        device=q_w.device,
+    )
     for i in range(pack_factor):
         q_packed |= q_w[:, i::pack_factor] << (num_bits * i)
 
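
The packing loop above interleaves values with stride `pack_factor` rather than packing contiguous runs. A tiny self-contained sketch of the resulting layout, assuming `pack_factor = 32 // num_bits` (its definition sits outside this hunk):

    num_bits = 4
    pack_factor = 32 // num_bits  # 8 nibbles per 32-bit word
    q_w = torch.arange(8, dtype=torch.int64).reshape(1, 8)  # values 0..7
    q_packed = torch.zeros((1, 1), dtype=torch.int64)
    for i in range(pack_factor):
        q_packed |= q_w[:, i::pack_factor] << (num_bits * i)
    # Column j of each 8-column group lands in bits [4*j, 4*j + 4):
    assert q_packed.item() == 0x76543210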
@@ -245,13 +244,10 @@ def _to_marlin_weights(
 
 
 def _from_marlin_weights(
-        q_packed: torch.Tensor,
-        size_k: int,
-        size_n: int,
-        num_bits: int
-    ) -> torch.Tensor:
+    q_packed: torch.Tensor, size_k: int, size_n: int, num_bits: int
+) -> torch.Tensor:
     """Converts a weight tensor in the marlin 2:4 format to a regular quantized 2:4 sparse format.
-    
+
     Args:
         q_packed (torch.Tensor): The weight tensor in the marlin 2:4 format.
         size_k (int): The number of input features.
@@ -269,52 +265,54 @@ def _from_marlin_weights(
     # Original implementation uses numpy + uint32 but we need to use int64 because torch.uint32
     # does not support rshift_cpu.
     q_packed = q_packed.cpu().to(torch.int64)
-    q_w_unpacked = torch.zeros((q_packed.shape[0], q_packed.shape[1] * pack_factor), dtype=torch.int64, device=q_packed.device)
+    q_w_unpacked = torch.zeros(
+        (q_packed.shape[0], q_packed.shape[1] * pack_factor),
+        dtype=torch.int64,
+        device=q_packed.device,
+    )
     for i in range(pack_factor):
-        q_w_unpacked[:, i::pack_factor] = (q_packed >> (num_bits * i)) & ((1 << num_bits) - 1)
+        q_w_unpacked[:, i::pack_factor] = (q_packed >> (num_bits * i)) & (
+            (1 << num_bits) - 1
+        )
 
     q_w_unpacked = q_w_unpacked.to(orig_device, dtype=torch.int32)
 
-    q_w_comp = utils.reverse_marlin_permute_weights(q_w_unpacked, size_k, size_n, perm_24)
+    q_w_comp = utils.reverse_marlin_permute_weights(
+        q_w_unpacked, size_k, size_n, perm_24
+    )
     return q_w_comp
 
 
 def _to_marlin_scales(
-        scales: torch.Tensor,
-        size_k: int,
-        size_n: int,
-        group_size: int,
-        num_bits: int
-    ) -> torch.Tensor:
+    scales: torch.Tensor, size_k: int, size_n: int, group_size: int, num_bits: int
+) -> torch.Tensor:
     """Converts a scale tensor to the format necessary for marlin.
     Args:
         scales (torch.Tensor): The scale tensor.
         size_k (int): The number of input features.
         size_n (int): The number of output features.
         group_size (int): The group size that was applied during quantization.
         num_bits (int): The number of bits used for quantization.
-    
+
     Returns:
         torch.Tensor: The scale tensor in the marlin format.
     """
     _, scale_perm_24, scale_perm_single_24 = utils.get_perms_24(num_bits)
     if group_size < size_k and group_size != -1:
         scales = scales.reshape((-1, len(scale_perm_24)))[:, scale_perm_24]
     else:
-        scales = scales.reshape((-1, len(scale_perm_single_24)))[:, scale_perm_single_24]
+        scales = scales.reshape((-1, len(scale_perm_single_24)))[
+            :, scale_perm_single_24
+        ]
     scales = scales.reshape((-1, size_n)).contiguous()
     return scales
 
 
 def _from_marlin_scale(
-        scales: torch.Tensor,
-        size_k: int,
-        size_n: int,
-        group_size: int,
-        num_bits: int
-    ) -> torch.Tensor:
+    scales: torch.Tensor, size_k: int, size_n: int, group_size: int, num_bits: int
+) -> torch.Tensor:
     """Converts a scale tensor from the marlin format to their original format.
-    
+
     Args:
         scales (torch.Tensor): The scale tensor in the marlin format.
         size_k (int): The number of input features.
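
The shift-and-mask loop in `_from_marlin_weights` above is the exact inverse of the packing loop shown earlier; continuing that sketch:

    q_packed = torch.tensor([[0x76543210]], dtype=torch.int64)
    num_bits, pack_factor = 4, 8
    q_w_unpacked = torch.zeros((1, 8), dtype=torch.int64)
    for i in range(pack_factor):
        q_w_unpacked[:, i::pack_factor] = (q_packed >> (num_bits * i)) & (
            (1 << num_bits) - 1
        )
    assert q_w_unpacked.tolist() == [[0, 1, 2, 3, 4, 5, 6, 7]]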
@@ -329,5 +327,7 @@ def _from_marlin_scale(
         scales = scales.reshape((-1, len(scale_perm_24)))[:, scale_perm_24]
         return scales.reshape((size_k // group_size, size_n))
     else:
-        scales = scales.reshape((-1, len(scale_perm_single_24)))[:, scale_perm_single_24]
-        return scales.reshape((1, -1))
+        scales = scales.reshape((-1, len(scale_perm_single_24)))[
+            :, scale_perm_single_24
+        ]
+        return scales.reshape((1, -1))
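
To put the reformatted signatures in context, here is the public round trip as a sketch: the quantizer producing `q_w_24` and `scales` lives outside this file, and the `(packed weights, packed scales, meta)` return order is an assumption read off the parameter order of `unpack_from_marlin_24`.

    # num_bits/group_size values are illustrative; return order is assumed.
    packed_w, packed_scales, meta = pack_to_marlin_24(q_w_24, scales, num_bits=4, group_size=128)
    w_24_restored, scales_restored = unpack_from_marlin_24(
        packed_w, packed_scales, meta, torch.Size((in_features, out_features)),
        group_size=128, num_bits=4,
    )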