Commit 000a490

Add a developer guide for exporting to executorch (#1219)
* Add a developer guide for exporting to executorch

  Summary: The main requirement for exporting a quantized model to ExecuTorch is that we want to preserve some high level ops so they can be lowered to ExecuTorch ops. Examples of ops that are already preserved are quantize_affine/dequantize_affine/choose_qparams_affine, which can be matched in ExecuTorch for pattern matching. This PR adds an example of how to define and preserve a quantized embedding_byte op; the main util function we use is `torchao.utils._register_custom_op`.

  Test Plan: python tutorials/developer_api_guide/export_to_executorch.py

* address comments

* docs
1 parent 88d604f commit 000a490
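
The preservation mechanism the summary describes can be shown with plain `torch.library`, which is roughly what `torchao.utils._register_custom_op` automates (it additionally infers the op schema from the Python type annotations). A minimal sketch; the op name `quant::double_it` and its schema are invented for illustration:

import torch

lib = torch.library.Library("quant", "FRAGMENT")
lib.define("double_it(Tensor x) -> Tensor")

def _double_it_impl(x: torch.Tensor) -> torch.Tensor:
    # plain eager implementation; since no decomposition is registered,
    # torch.export keeps the call as a single torch.ops.quant.double_it node
    return x + x

lib.impl("double_it", _double_it_impl, "CompositeExplicitAutograd")

assert torch.equal(torch.ops.quant.double_it(torch.ones(3)), torch.ones(3) * 2)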

File tree

2 files changed: +86 −1 lines changed

torchao/utils.py
tutorials/developer_api_guide/export_to_executorch.py

torchao/utils.py

Lines changed: 1 addition & 1 deletion

@@ -276,7 +276,7 @@ def unwrap_tensor_subclass(model, filter_fn=None):
     for name, child in model.named_children():
         # make sure child.weight is a tensor subclass
         if (
-            isinstance(child, torch.nn.Linear) and
+            (isinstance(child, torch.nn.Linear) or isinstance(child, torch.nn.Embedding)) and
             hasattr(child, "weight") and
             type(child.weight) is not torch.Tensor and
             type(child.weight) is not torch.nn.Parameter and
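
This one-line change is what lets the new tutorial export an `nn.Embedding`: previously only `nn.Linear` children had their tensor-subclass weights unwrapped before `torch.export`. A minimal usage sketch (the weight swap is elided here; the tutorial below performs it with `to_my_dtype_extended`):

import torch
import torchao

m = torch.nn.Sequential(torch.nn.Embedding(4096, 128))
# ... replace m[0].weight with a quantized tensor subclass here ...
m = torchao.utils.unwrap_tensor_subclass(m)  # now also handles nn.Embedding
ep = torch.export.export(m, (torch.randint(0, 4096, (1, 6)),))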
tutorials/developer_api_guide/export_to_executorch.py

Lines changed: 85 additions & 0 deletions

"""
This tutorial shows how to preserve higher level operators in the model so that they can be used in ExecuTorch.

Specifically, we define and preserve a `torch.ops.quant.embedding_byte` op that works with quantized weights
and survives `torch.export.export`. From there we can follow the ExecuTorch tutorial
(https://pytorch.org/executorch/stable/tutorials/export-to-executorch-tutorial.html#lowering-to-edge-dialect)
to lower the model to ExecuTorch, or rely on https://github.com/pytorch/executorch/tree/main/examples/models/llama
and https://github.com/pytorch/torchchat to export to the target device.

The same approach also supports exporting the model to other platforms such as ONNX.
"""
from typing import List

import torch
import torchao

from my_dtype_tensor_subclass import MyDTypeTensor
from torchao.quantization.quant_primitives import dequantize_affine
from torchao.utils import _register_custom_op

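# Roughly, `_register_custom_op` infers the op schema from the decorated
# function's type annotations (via infer_schema), registers the op on the
# given torch.library.Library, and returns the resulting torch.ops overload,
# so calls to it are preserved as a single `torch.ops.quant.*` node through
# `torch.export.export` instead of being decomposed into aten ops.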
quant_lib = torch.library.Library("quant", "FRAGMENT")
register_custom_op = _register_custom_op(quant_lib)


class MyDTypeTensorExtended(MyDTypeTensor):
    pass


implements = MyDTypeTensorExtended.implements
to_my_dtype_extended = MyDTypeTensorExtended.from_float

aten = torch.ops.aten


# NOTE: the function name must start with `_`; the leading underscore is
# stripped when the op is registered, so this registers `torch.ops.quant.embedding_byte`
# NOTE: the typing must be compatible with infer_schema
# (https://github.com/pytorch/pytorch/blob/main/torch/_library/infer_schema.py)
@register_custom_op
def _embedding_byte(
    int_data: torch.Tensor,
    block_size: List[int],
    weight_scales: torch.Tensor,
    indices: torch.Tensor,
) -> torch.Tensor:
    # dequantize the int weight back to floating point, then run the
    # normal aten embedding lookup
    weight = dequantize_affine(
        int_data,
        block_size,
        weight_scales,
        None,
        int_data.dtype,
    )
    return torch.ops.aten.embedding.default(weight, indices)

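
# When an nn.Embedding's weight is a MyDTypeTensorExtended, a call to
# torch.nn.functional.embedding is intercepted by the subclass and re-routed
# here via the `implements` registration, so the exported graph contains a
# single torch.ops.quant.embedding_byte node rather than the decomposed
# dequantize + aten.embedding sequence.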
@implements(torch.nn.functional.embedding)
def _(func, types, args, kwargs):
    indices = args[0]
    weight = args[1]
    tensor_impl = weight.tensor_impl
    int_data, scale = tensor_impl.get_plain()
    # one quantization block per embedding row, i.e. a single scale per row
    block_size = (1, int_data.shape[-1])
    return torch.ops.quant.embedding_byte(int_data, block_size, scale, indices)


def main():
    m = torch.nn.Sequential(
        torch.nn.Embedding(4096, 128)
    )
    input = torch.randint(0, 4096, (1, 6))

    # swap the plain weight for the quantized tensor subclass and check accuracy
    m[0].weight = torch.nn.Parameter(to_my_dtype_extended(m[0].weight), requires_grad=False)
    y_ref = m[0].weight.dequantize()[input]
    y_q = m(input)
    from torchao.quantization.utils import compute_error
    sqnr = compute_error(y_ref, y_q)
    assert sqnr > 45.0

    # export
    m_unwrapped = torchao.utils.unwrap_tensor_subclass(m)
    m_exported = torch.export.export(m_unwrapped, (input,)).module()
    y_q_exported = m_exported(input)

    assert torch.equal(y_ref, y_q_exported)
    # make sure the custom op is preserved in the exported graph
    ops = [n.target for n in m_exported.graph.nodes]
    print(m_exported)
    assert torch.ops.quant.embedding_byte.default in ops


if __name__ == "__main__":
    main()
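
The final assert is the hook the commit summary refers to: because `torch.ops.quant.embedding_byte` survives export as a single node, a downstream lowering can pattern match on it. A minimal sketch; `find_preserved_ops` is a hypothetical helper, not part of the tutorial:

def find_preserved_ops(gm: torch.fx.GraphModule):
    # collect calls to the preserved custom op; an ExecuTorch lowering pass
    # would pattern match on nodes like these and swap in its own kernel
    return [
        n for n in gm.graph.nodes
        if n.op == "call_function" and n.target is torch.ops.quant.embedding_byte.default
    ]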
