
Commit 5bdee5e

float8 in quantize()
1 parent c842d50

File tree

2 files changed: +35 −4 lines


torchao/_models/llama/eval.py

Lines changed: 17 additions & 2 deletions
@@ -20,6 +20,8 @@
     fpx_weight_only,
     uintx_weight_only,
     unwrap_tensor_subclass,
+    float8_weight_only,
+    float8_dynamic_activation_float8_weight,
 )
 from torchao._models._eval import TransformerEvalWrapper, InputRecorder
 
@@ -28,6 +30,7 @@
 from torchao.quantization.GPTQ import Int4WeightOnlyGPTQQuantizer
 from torchao._models.llama.model import prepare_inputs_for_model
 from torchao.utils import TORCH_VERSION_AT_LEAST_2_5
+from torchao.quantization.observer import PerTensor, PerRow
 
 def run_evaluation(
     checkpoint_path: Path,
@@ -117,7 +120,19 @@ def run_evaluation(
         else:
             if not TORCH_VERSION_AT_LEAST_2_5:
                 unwrap_tensor_subclass(model)
-
+        if "float8wo" in quantization:
+            quantize_(model, float8_weight_only())
+        if "float8dq" in quantization:
+            granularity = quantization.split("-")[-1]
+            if granularity == "float8dq":
+                granularity = "tensor"
+            if granularity == "tensor":
+                granularity = PerTensor()
+            elif granularity == "row":
+                granularity = PerRow()
+            else:
+                raise ValueError(f"float8dq granularity needs to be either tensor or row but got {granularity}")
+            quantize_(model, float8_dynamic_activation_float8_weight(granularity=granularity))
     if compile:
         model = torch.compile(model, mode="max-autotune", fullgraph=True)
     with torch.no_grad():
@@ -140,7 +155,7 @@ def run_evaluation(
     parser.add_argument('--limit', type=int, default=None, help='Number of eval samples to evaluate')
     parser.add_argument('--precision', type=lambda x: getattr(torch, x.split(".")[-1]), default=torch.bfloat16, help='dtype precision to use')
     parser.add_argument('--device', type=str, default="cuda", help='Device to use for evaluation')
-    parser.add_argument("-q", "--quantization", type=str, help="Which quantization techniques to apply: int8dq, int8wo, int4wo-<groupsize>, int4wo-<groupsize>-gptq, int4wo-<groupsize>-hqq, uintx-<nbits>-<groupsize>, uintx-<nbits>-<groupsize>-hqq")
+    parser.add_argument("-q", "--quantization", type=str, help="Which quantization techniques to apply: int8dq, int8wo, int4wo-<groupsize>, int4wo-<groupsize>-gptq, int4wo-<groupsize>-hqq, uintx-<nbits>-<groupsize>, uintx-<nbits>-<groupsize>-hqq, float8wo, float8dq-<granularity>")
     parser.add_argument('--compile', action='store_true', help='Whether to compile the model.')
     parser.add_argument('--max_length', type=int, default=None, help='Length of text to process at one time')
    parser.add_argument('--calibration_tasks', type=str, nargs='+', default=['wikitext'], help='tasks to do gptq calibration on, if doing gptq')
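
With this change, the new techniques can be selected from the command line. Hypothetical invocations (the checkpoint path is illustrative, not part of this commit):

    python torchao/_models/llama/eval.py --checkpoint_path checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth -q float8wo
    python torchao/_models/llama/eval.py --checkpoint_path checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth -q float8dq-row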

torchao/_models/llama/generate.py

Lines changed: 18 additions & 2 deletions
@@ -210,7 +210,10 @@ def main(
             fpx_weight_only,
             uintx_weight_only,
             autoquant,
-            unwrap_tensor_subclass
+            unwrap_tensor_subclass,
+            float8_weight_only,
+            float8_dynamic_activation_float8_weight,
         )
+        from torchao.quantization.observer import PerTensor, PerRow
         if "int8wo" in quantization:
             quantize_(model, int8_weight_only())
@@ -290,6 +292,19 @@ def main(
             dtype = _NBITS_TO_DTYPE[nbits]
             group_size = int(_quant_args[1])
             quantize_(model, uintx_weight_only(dtype, group_size, use_hqq=use_hqq))
+        if "float8wo" in quantization:
+            quantize_(model, float8_weight_only())
+        if "float8dq" in quantization:
+            granularity = quantization.split("-")[-1]
+            if granularity == "float8dq":
+                granularity = "tensor"
+            if granularity == "tensor":
+                granularity = PerTensor()
+            elif granularity == "row":
+                granularity = PerRow()
+            else:
+                raise ValueError(f"float8dq granularity needs to be either tensor or row but got {granularity}")
+            quantize_(model, float8_dynamic_activation_float8_weight(granularity=granularity))
         if "autoquant" in quantization:
             if "autoquant-int4" == quantization:
                 model = autoquant(model, manual=True, qtensor_class_list = torchao.quantization.DEFAULT_INT4_AUTOQUANT_CLASS_LIST)
@@ -459,7 +474,7 @@ def callback(x):
     parser.add_argument('--top_k', type=int, default=200, help='Top-k for sampling.')
     parser.add_argument('--temperature', type=float, default=0.8, help='Temperature for sampling.')
     parser.add_argument('--checkpoint_path', type=Path, default=Path("../../../checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth"), help='Model checkpoint path.')
-    parser.add_argument('-q', '--quantization', type=str, help='Which quantization techniques to apply: int8dq, int8wo, int4wo-<groupsize>, autoquant, autoquant-int4, int4wo-<groupsize>-hqq, autoround-<model_device>-<quant_lm_head>-<iters>-<groupsize>-<batch_size>-<seqlen>-<nsamples>, uintx-<nbits>-<groupsize>, uintx-<nbits>-<groupsize>-hqq')
+    parser.add_argument('-q', '--quantization', type=str, help='Which quantization techniques to apply: int8dq, int8wo, int4wo-<groupsize>, autoquant, autoquant-int4, int4wo-<groupsize>-hqq, autoround-<model_device>-<quant_lm_head>-<iters>-<groupsize>-<batch_size>-<seqlen>-<nsamples>, uintx-<nbits>-<groupsize>, uintx-<nbits>-<groupsize>-hqq, float8wo, float8dq-<granularity>')
     parser.add_argument('--kv_cache_quantization', action='store_true', help='Whether to quantize the KV cache')
     parser.add_argument('--cache_size', type=int, default=None, help='Force size of cache to be a certain number of tokens, if not set, will use max_new_tokens+prompt_size')
     parser.add_argument('--linear_causal_mask', action='store_true', help='Whether to use the memory efficient, but slightly less fast, linear causal mask (important for long context lengths)')
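
For reference, a minimal standalone sketch of what the two new flags reduce to when applied to a model. The toy Linear module stands in for Llama and is illustrative only; import paths mirror the eval.py hunk above but may differ across torchao versions, and float8 kernels generally require a GPU with native float8 support (e.g. H100):

    import torch
    from torchao.quantization import (
        quantize_,
        float8_weight_only,
        float8_dynamic_activation_float8_weight,
    )
    from torchao.quantization.observer import PerTensor, PerRow  # granularity types

    # Toy stand-in for the Llama model (not part of this commit).
    model = torch.nn.Sequential(torch.nn.Linear(1024, 1024, bias=False))
    model = model.to(device="cuda", dtype=torch.bfloat16)

    # -q float8wo: weight-only float8 quantization.
    quantize_(model, float8_weight_only())

    # -q float8dq-tensor / -q float8dq-row: dynamic float8 activations with
    # float8 weights at per-tensor or per-row scaling granularity. Apply to
    # a fresh model rather than re-quantizing the one above.
    # quantize_(model, float8_dynamic_activation_float8_weight(granularity=PerRow()))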
