
Commit b195f11
(wip) gemlite integration and llama batch size > 1
Summary: compile isn't working with gemlite; the kernel wrapper probably needs to be rewritten in a more compatible way. Added batch size > 1 to the llama model; see benchmark_results.txt for numbers.

Squashed commit history:
- new gemlite integration using pip install
- tests ran
- fixing gemlite to do int4 matmul instead of fp16 x fp16
- running tests
- more testing
- AQT integration (wip)
- wip
- testing on gemlite a100_int8_tuning branch
- gemlite subclass testing, bitpacking 8 bits
- bug fixing
- hicham fixes
- new benchmarks
- testing gemlite 8 bit
- WIP
1 parent 039cef4 commit b195f11

File tree

9 files changed: +900 -149 lines

torchao/_models/llama/benchmark_results.txt

Lines changed: 280 additions & 51 deletions
Large diffs are not rendered by default.

torchao/_models/llama/benchmarks.sh

Lines changed: 185 additions & 91 deletions
Large diffs are not rendered by default.

torchao/_models/llama/generate.py

Lines changed: 111 additions & 1 deletion
@@ -19,6 +19,7 @@
 import torchao
 from torchao.quantization.quant_primitives import MappingType
 from torchao.utils import get_model_size_in_bytes, TORCH_VERSION_AT_LEAST_2_5
+from torchao.utils import unwrap_tensor_subclass
 
 torch.sparse.SparseSemiStructuredTensor._FORCE_CUTLASS = False
 
@@ -171,7 +172,7 @@ def decode_n_tokens(
             )
             next_token, next_prob = next_token.clone(), next_prob.clone()
             input_pos += 1
-            new_tokens.append(next_token)
+            new_tokens.append(next_token.clone())
             callback(new_tokens[-1])
             new_probs.append(next_prob)
             cur_token = next_token
@@ -368,6 +369,7 @@ def ffn_or_attn_only(mod, fqn):
             int8_weight_only,
             quantize_,
             uintx_weight_only,
+            gemlite_uintx_weight_only,
         )
 
         from torchao.quantization.granularity import PerRow, PerTensor
@@ -377,6 +379,113 @@ def ffn_or_attn_only(mod, fqn):
             from torchao.prototype.spinquant import apply_spinquant
 
             apply_spinquant(model)
+        if "gemsub" in quantization:
+            import os, pwd
+            import gemlite
+            from gemlite.core import GemLiteLinearTriton, set_autotune
+            _quant_args = quantization.split("-")
+            bit_width = int(_quant_args[-2])
+            group_size = None if _quant_args[-1] == 'None' else int(_quant_args[-1])  # TODO is 'None' working?
+            try:
+                packing_bitwidth = int(_quant_args[-3])
+            except:
+                packing_bitwidth = 8
+
+            quantize_(model, gemlite_uintx_weight_only(group_size, bit_width, packing_bitwidth))
+
+            # try to load gemlite kernel config
+            try:
+                GemLiteLinearTriton.load_config(f"/tmp/{pwd.getpwuid(os.getuid()).pw_gecos}_gemlite.json")
+            except:
+                pass
+            print("running calibration")
+            generate(
+                model,
+                encode_tokens(tokenizer, prompt, bos=True, device=device),
+                max_new_tokens,
+                batch_size,
+                interactive=False,
+                temperature=temperature,
+                top_k=top_k,
+            )
+
+            GemLiteLinearTriton.cache_config(f"/tmp/{pwd.getpwuid(os.getuid()).pw_gecos}_gemlite.json")
+        if "gemlite" in quantization:
+            import gemlite
+            import hqq
+            from gemlite.core import GemLiteLinearTriton, DType, set_autotune
+            from torchao.quantization.quant_api import _replace_with_custom_fn_if_matches_filter, _is_linear
+            from hqq.core.quantize import HQQLinear, BaseQuantizeConfig
+            _quant_args = quantization.split("-")
+
+            W_nbits = int(_quant_args[-2])
+            group_size = None if _quant_args[-1] == 'None' else int(_quant_args[-1])  # None is channel-wise
+
+            assert W_nbits in [1, 2, 4, 8], f"W_nbits needs to be in [1, 2, 4, 8], got {W_nbits} for gemlite-<W_nbits>-<group_size>"
+            assert group_size in [32, 64, 128, 256, 512, 1024, None], f"group_size needs to be in [32, 64, 128, 256, 512, 1024, None], got {group_size} for gemlite-<W_nbits>-<group_size>"
+            assert precision == torch.float16, f"gemlite only supports float16 precision, got {precision}"
+
+            quant_config = BaseQuantizeConfig(nbits=W_nbits, group_size=group_size, quant_zero=False, quant_scale=False, axis=1)
+            quant_config['weight_quant_params']['optimize'] = False
+
+            set_autotune({'GEMV_REVSPLITK': True, 'GEMV': True, 'GEMM_SPLITK': True, 'GEMM': True}, exhaustive=False, use_cuda_graph=False)
+
+            def replace_fn(mod):
+                if not isinstance(mod, torch.nn.Linear):
+                    return mod
+
+                in_features = mod.in_features
+                out_features = mod.out_features
+
+                compute_dtype = mod.weight.dtype
+                input_dtype, output_dtype = DType.FP16, DType.FP16
+
+                hqq_layer = HQQLinear(mod, quant_config=quant_config, compute_dtype=compute_dtype, device=device, del_orig=False)
+                if hqq_layer.meta["group_size"] is None:
+                    hqq_layer.meta["group_size"] = hqq_layer.in_features
+
+                gemlite_linear = GemLiteLinearTriton(
+                    hqq_layer.meta["nbits"],
+                    group_size=hqq_layer.meta["group_size"],
+                    in_features=hqq_layer.in_features,
+                    out_features=hqq_layer.out_features,
+                    input_dtype=DType.FP16,
+                    output_dtype=DType.FP16,
+                )
+                orig_shape = hqq_layer.meta['shape']
+                W_q = hqq_layer.unpack(dtype=torch.uint8).view(orig_shape)  # expects uint8 for Wn quantization!
+                scales = hqq_layer.meta['scale'].clone()
+                zeros = hqq_layer.meta['zero'].clone()
+                bias = hqq_layer.bias.clone() if (hqq_layer.bias is not None) else None
+                gemlite_linear.pack(W_q, scales, zeros, bias=bias, fma_mode=False, packing_bitwidth=32, contiguous=False)
+
+                del hqq_layer.W_q
+                del hqq_layer.meta
+                del hqq_layer
+                torch.cuda.empty_cache()
+
+                return gemlite_linear
+
+            _replace_with_custom_fn_if_matches_filter(model, replace_fn, _is_linear)
+            import gc
+            gc.collect()
+
+            generate(
+                model,
+                encode_tokens(tokenizer, prompt, bos=True, device=device),
+                max_new_tokens,
+                batch_size,
+                interactive=False,
+                temperature=temperature,
+                top_k=top_k,
+            )
+
         if "int8wo" in quantization:
             quantize_(model, int8_weight_only())
         if "int8dq" in quantization:
@@ -1053,6 +1162,7 @@ def callback(x):
     )
 
     args = parser.parse_args()
+    print(args)
     main(
         args.prefill_size,
         args.prompt,
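For reference, a minimal sketch of how the gemsub/gemlite branches above appear to parse the --quantization flag; parse_gemlite_args is a hypothetical helper written for illustration, not code from this commit:

def parse_gemlite_args(quantization: str):
    # "gemsub-32-4-64"  -> packing_bitwidth=32, bit_width=4, group_size=64
    # "gemlite-4-None"  -> bit_width=4, group_size=None (channel-wise)
    args = quantization.split("-")
    bit_width = int(args[-2])
    group_size = None if args[-1] == "None" else int(args[-1])
    try:
        packing_bitwidth = int(args[-3])  # optional leading field
    except ValueError:  # args[-3] is the scheme name, e.g. "gemlite"
        packing_bitwidth = 8  # default used by the gemsub branch above
    return packing_bitwidth, bit_width, group_size

print(parse_gemlite_args("gemsub-32-4-64"))  # (32, 4, 64)
print(parse_gemlite_args("gemlite-4-None"))  # (8, 4, None)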

torchao/_models/llama/model.py

Lines changed: 9 additions & 3 deletions
@@ -170,7 +170,9 @@ def setup_caches(self, max_batch_size, max_seq_length, training: bool=False, kv_
         max_seq_length = find_multiple(max_seq_length, 8)
         self.max_seq_length = max_seq_length
         self.max_batch_size = max_batch_size
-        dtype = self.output.weight.dtype
+        dtype = None
+        if hasattr(self.output, "weight"):
+            dtype = self.output.weight.dtype
         # For quantized layers, dtype is encoded in scales
         if hasattr(self.output, "scales"):
             dtype = self.output.scales.dtype
@@ -243,7 +245,11 @@ def forward(self, idx: Tensor, input_pos: Optional[Tensor] = None) -> Tensor:
         x = self.tok_embeddings(idx)
 
         for i, layer in enumerate(self.layers):
-            x = layer(x, input_pos, freqs_cis, mask)
+            x_new = layer(x, input_pos, freqs_cis, mask)
+            # if torch.isnan(x_new).sum() > 0:
+            #     import fbvscode; fbvscode.set_trace()
+            x = x_new
+
         x = self.norm(x)
         logits = self.output(x)
         return logits
@@ -311,7 +317,7 @@ def forward(self, x: Tensor, freqs_cis: Tensor, mask: Optional[Tensor], input_po
 
         k = k.repeat_interleave(self.n_head // self.n_local_heads, dim=1)
         v = v.repeat_interleave(self.n_head // self.n_local_heads, dim=1)
-        if mask is not None:
+        if mask is not None:
             y = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0)
         else:
             y = F.scaled_dot_product_attention(q, k, v, dropout_p=0.0, is_causal=True)
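For reference, the setup_caches hunk guards against output layers that lack a plain weight attribute (e.g. after the gemlite linear replacement in generate.py). A hedged sketch of the resulting lookup order follows; resolve_cache_dtype and its default argument are illustrative additions, since the commit itself simply leaves dtype as None when neither attribute exists:

import torch

def resolve_cache_dtype(output_layer, default=torch.float16):
    # mirrors the guarded lookup in setup_caches (sketch, not the commit's code)
    dtype = None
    if hasattr(output_layer, "weight"):   # plain nn.Linear
        dtype = output_layer.weight.dtype
    if hasattr(output_layer, "scales"):   # quantized layers encode dtype in scales
        dtype = output_layer.scales.dtype
    return dtype if dtype is not None else default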

torchao/dtypes/affine_quantized_tensor.py

Lines changed: 7 additions & 1 deletion
@@ -225,6 +225,7 @@ def from_hp_to_intx(
                 else input_float.dtype
             )
             device = input_float.device
+            from torchao.dtypes.uintx import TensorCoreTiledLayout
             data, scale, zero_point, _ = choose_qparams_and_quantize_affine_hqq(
                 input_float,
                 nbits=nbits,
@@ -233,7 +234,12 @@ def from_hp_to_intx(
                 compute_dtype=compute_dtype,
                 device=device,
                 verbose=False,
-                raw_output=False,
+                raw_output=not isinstance(_layout, TensorCoreTiledLayout),
+                # raw_output=False is basically the 'convert to TensorCoreTiledLayout zero_point version' option (add scale*midpoint)
+                # note: in choose_qparams_affine, preserve_zero=False does this same thing while also controlling whether
+                # zero is preserved.
+                # TODO: uncouple preserve_zero and conversion of zero_point to the TensorCoreTiledLayout version
+                # TODO: move the conversion of zero_point out of quant_primitives and into TensorCoreTiledLayout.from_plain
             )
             data = data.to(target_dtype)
         else:
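For reference, a numeric illustration of the two zero_point conventions the inline comment contrasts. With an integer zero_point, dequantization is (q - zp) * scale; the TensorCoreTiledLayout-style convention stores a floating-point zero_point with the midpoint folded in ("add scale*midpoint"). The values below are made up; this is a sketch of the arithmetic, not the kernel code:

import torch

scale, zp_int, nbits = 0.05, 3, 4
mid = 2 ** (nbits - 1)           # midpoint: 8 for 4-bit
q = torch.tensor([0, 3, 7, 15])  # raw quantized values

# integer zero_point convention (roughly the raw_output=True path)
dq_int = (q - zp_int) * scale

# float zero_point convention with scale*midpoint folded in
zp_float = (mid - zp_int) * scale
dq_float = (q - mid) * scale + zp_float

assert torch.allclose(dq_int, dq_float)  # same dequantized values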

torchao/dtypes/affine_quantized_tensor_ops.py

Lines changed: 8 additions & 0 deletions
@@ -43,6 +43,10 @@
     _linear_bf16_act_uint4_weight_check,
     _linear_bf16_act_uint4_weight_impl,
 )
+from torchao.dtypes.uintx.gemlite_layout import (
+    _linear_fp_act_int4_weight_gemlite_check,
+    _linear_fp_act_int4_weight_gemlite_impl,
+)
 from torchao.quantization.quant_primitives import dequantize_affine
 from torchao.utils import (
     fill_defaults,
@@ -135,6 +139,10 @@ def _register_aqt_quantized_linear_dispatches():
             _linear_int8_act_int4_weight_marlin_qqq_check,
             _linear_int8_act_int4_weight_marlin_qqq_impl,
         ),
+        (
+            _linear_fp_act_int4_weight_gemlite_check,
+            _linear_fp_act_int4_weight_gemlite_impl,
+        )
     ]:
         register_aqt_quantized_linear_dispatch(dispatch_condition, impl)
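For reference, each entry registered above pairs a check predicate with an implementation; the dispatcher routes a quantized linear call to the first pair whose check accepts the (activation, weight, bias) triple. The real gemlite pair lives in torchao/dtypes/uintx/gemlite_layout.py; the simplified pair below only sketches the expected shape and is not the commit's code:

import torch

def _example_gemlite_check(input_tensor, weight_tensor, bias):
    # accept only the combinations the kernel can handle (illustrative)
    return input_tensor.dtype == torch.float16 and hasattr(weight_tensor, "tensor_impl")

def _example_gemlite_impl(input_tensor, weight_tensor, bias):
    # stand-in for the triton kernel call: dequantize and fall back to F.linear
    return torch.nn.functional.linear(input_tensor, weight_tensor.dequantize(), bias)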
