From 3fd99631988b53ebf56657497e2fa13550fefbdd Mon Sep 17 00:00:00 2001
From: Mengwei Liu
Date: Tue, 12 Nov 2024 15:09:41 -0800
Subject: [PATCH 1/2] [llama-mm] Add unit tests for exporting MultiHeadAttention with KVCache

Summary:
To make sure we can always export MultiHeadAttention.

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:
---
 extension/llm/modules/kv_cache.py            |  5 +-
 extension/llm/modules/test/test_attention.py | 67 ++++++++++++++------
 2 files changed, 50 insertions(+), 22 deletions(-)

diff --git a/extension/llm/modules/kv_cache.py b/extension/llm/modules/kv_cache.py
index 827078a40a8..eb95cab0838 100644
--- a/extension/llm/modules/kv_cache.py
+++ b/extension/llm/modules/kv_cache.py
@@ -105,9 +105,8 @@ def update(
             f", but found new key tensors with batch size {k_val.shape[0]}!"
         )
 
-        assert (
-            self.cache_pos[0] + seq_len
-        ) <= self.max_seq_len, f"self.cache_pos[0]: {self.cache_pos[0]} + seq_len: {seq_len} > self.max_seq_len: {self.max_seq_len}"
+        assert (self.cache_pos[0] + seq_len) <= self.max_seq_len
+
         k_out = self.k_cache
         v_out = self.v_cache
 
diff --git a/extension/llm/modules/test/test_attention.py b/extension/llm/modules/test/test_attention.py
index 565e8c67d75..3c043c130b1 100644
--- a/extension/llm/modules/test/test_attention.py
+++ b/extension/llm/modules/test/test_attention.py
@@ -4,6 +4,8 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
+import os
+import tempfile
 import unittest
 
 import torch
@@ -13,6 +15,7 @@
     MultiHeadAttention as ETMultiHeadAttention,
 )
 from executorch.runtime import Runtime
+from torch._inductor.package import load_package, package_aoti
 from torch.testing import assert_close
 from torchtune.models.llama3_1._position_embeddings import Llama3ScaledRoPE
 from torchtune.modules.attention import MultiHeadAttention as TTMultiHeadAttention
@@ -130,34 +133,62 @@ def test_attention_eager(self):
 
     def test_attention_export(self):
         # Self attention.
-        et_mha_ep = torch.export.export(
-            self.et_mha,
-            (self.x, self.x),
-            kwargs={"input_pos": self.input_pos},
-            dynamic_shapes=self.dynamic_shapes,
-        )
+
+        # test with kv cache
+        self.et_mha.setup_cache(1, dtype=torch.float32, max_seq_len=100)
+        self.tt_mha.setup_cache(1, dtype=torch.float32, max_seq_len=100)
+        with torch.no_grad():
+            et_mha_ep = torch.export.export(
+                self.et_mha,
+                (self.x, self.x),
+                kwargs={"input_pos": self.input_pos},
+                dynamic_shapes=self.dynamic_shapes,
+            )
         et_res = et_mha_ep.module()(self.x, self.x, input_pos=self.input_pos)
         tt_res = self.tt_mha(self.x, self.x, input_pos=self.input_pos)
         assert_close(et_res, tt_res)
 
-        # TODO: KV cache.
-
     def test_attention_aoti(self):
-        # TODO.
-        pass
+        # Self attention.
+
+        # test with kv cache
+        self.et_mha.setup_cache(1, dtype=torch.float32, max_seq_len=100)
+        self.tt_mha.setup_cache(1, dtype=torch.float32, max_seq_len=100)
+        with torch.no_grad():
+            so = torch._export.aot_compile(
+                self.et_mha,
+                args=(self.x, self.x),
+                kwargs={"input_pos": self.input_pos},
+                options={"aot_inductor.package": True},
+                dynamic_shapes=self.dynamic_shapes,
+            )
+        with tempfile.TemporaryDirectory() as tempdir:
+            path = package_aoti(os.path.join(tempdir, "mha.pt2"), so)
+            mha_aoti = load_package(path)
+
+            et_res = mha_aoti(self.x, self.x, input_pos=self.input_pos)
+            tt_res = self.tt_mha(self.x, self.x, input_pos=self.input_pos)
+            self.assertTrue(torch.allclose(et_res, tt_res))
 
     def test_attention_executorch(self):
         # Self attention.
-        et_mha_ep = torch.export.export(
-            self.et_mha,
-            (self.x, self.x),
-            kwargs={"input_pos": self.input_pos},
-            dynamic_shapes=self.dynamic_shapes,
-        )
+        # TODO: Fix kv cache
+        # self.et_mha.setup_cache(1, dtype=torch.float32, max_seq_len=100)
+        # self.tt_mha.setup_cache(1, dtype=torch.float32, max_seq_len=100)
+
+        with torch.no_grad():
+            et_mha_ep = torch.export.export(
+                self.et_mha,
+                (self.x, self.x),
+                kwargs={"input_pos": self.input_pos},
+                dynamic_shapes=self.dynamic_shapes,
+            )
         et_program = to_edge(
             et_mha_ep,
-            compile_config=EdgeCompileConfig(),
+            compile_config=EdgeCompileConfig(
+                _core_aten_ops_exception_list=[torch.ops.aten._assert_async.msg]
+            ),
         ).to_executorch()
         runtime = Runtime.get()
         program = runtime.load_program(et_program.buffer)
@@ -166,5 +197,3 @@ def test_attention_executorch(self):
 
         tt_res = self.tt_mha(self.x, self.x, input_pos=self.input_pos)
         assert_close(et_res[0], tt_res)
-
-        # TODO: KV cache.

From 4af2c24c2da6998c408db3cfcd79e03555a927c8 Mon Sep 17 00:00:00 2001
From: Mengwei Liu
Date: Tue, 12 Nov 2024 16:18:57 -0800
Subject: [PATCH 2/2] Use assert_close

Summary:

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:
---
 extension/llm/modules/test/test_attention.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/extension/llm/modules/test/test_attention.py b/extension/llm/modules/test/test_attention.py
index 3c043c130b1..70267eb7c41 100644
--- a/extension/llm/modules/test/test_attention.py
+++ b/extension/llm/modules/test/test_attention.py
@@ -167,9 +167,9 @@ def test_attention_aoti(self):
             path = package_aoti(os.path.join(tempdir, "mha.pt2"), so)
             mha_aoti = load_package(path)
 
-            et_res = mha_aoti(self.x, self.x, input_pos=self.input_pos)
+            aoti_res = mha_aoti(self.x, self.x, input_pos=self.input_pos)
             tt_res = self.tt_mha(self.x, self.x, input_pos=self.input_pos)
-            self.assertTrue(torch.allclose(et_res, tt_res))
+            assert_close(aoti_res, tt_res)
 
     def test_attention_executorch(self):
         # Self attention.
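
Note (not part of the patches above): a rough standalone sketch of the export-and-lower flow that test_attention_executorch exercises, for readers reproducing it outside the unittest harness. It assumes `et_mha`, `x`, and `input_pos` are an already-constructed attention module and sample inputs (built in the test's setUp, which these patches do not show), and the exact input packing for the runtime call may differ.

    import torch
    from executorch.exir import EdgeCompileConfig, to_edge
    from executorch.runtime import Runtime

    # Placeholders: et_mha, x, input_pos come from the test fixture (not shown here).
    with torch.no_grad():
        ep = torch.export.export(et_mha, (x, x), kwargs={"input_pos": input_pos})

    # Let the KV-cache bounds assert pass core-ATen verification, as the patch
    # does via _core_aten_ops_exception_list.
    et_program = to_edge(
        ep,
        compile_config=EdgeCompileConfig(
            _core_aten_ops_exception_list=[torch.ops.aten._assert_async.msg]
        ),
    ).to_executorch()

    # Run the lowered program with the ExecuTorch Python runtime.
    runtime = Runtime.get()
    program = runtime.load_program(et_program.buffer)
    method = program.load_method("forward")
    out = method.execute((x, x, input_pos))  # flattened positional inputs; exact packing may vary

The exception list is needed because the simplified KVCache bounds check in kv_cache.py lowers to aten._assert_async.msg, which is not a core ATen op, so the edge compile config whitelists it instead of failing verification.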