
Commit ca52cdc

[Experimental] Add Kleidi i8mm gemm kernels (#1295)
* Update git ignore
* [experimental] Add Kleidi compile def at the top level
* [Experimental] Add Kleidi i8mm gemm kernels

  Add kernel-level tests, with basic cross-compilation support. Tested with S24 + r26c:

```
[----------] 6 tests from test_kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm
[ RUN      ] test_kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm.k_eq_gs_32
[       OK ] test_kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm.k_eq_gs_32 (0 ms)
[ RUN      ] test_kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm.large_k_n_gs32
[       OK ] test_kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm.large_k_n_gs32 (79 ms)
[ RUN      ] test_kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm.even_n_gs32
[       OK ] test_kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm.even_n_gs32 (28 ms)
[ RUN      ] test_kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm.k_eq_gs128
[       OK ] test_kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm.k_eq_gs128 (3 ms)
[ RUN      ] test_kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm.clamp_k_eq_gs128
[       OK ] test_kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm.clamp_k_eq_gs128 (3 ms)
[ RUN      ] test_kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm.m_clamp_k_eq_gs128
[       OK ] test_kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm.m_clamp_k_eq_gs128 (5 ms)
[----------] 6 tests from test_kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm (121 ms total)

[----------] 6 tests from test_kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm
[ RUN      ] test_kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm.k_eq_gs_32
[       OK ] test_kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm.k_eq_gs_32 (0 ms)
[ RUN      ] test_kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm.large_k_n_gs32
[       OK ] test_kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm.large_k_n_gs32 (79 ms)
[ RUN      ] test_kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm.even_n_gs32
[       OK ] test_kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm.even_n_gs32 (28 ms)
[ RUN      ] test_kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm.k_eq_gs128
[       OK ] test_kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm.k_eq_gs128 (3 ms)
[ RUN      ] test_kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm.clamp_k_eq_gs128
[       OK ] test_kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm.clamp_k_eq_gs128 (3 ms)
[ RUN      ] test_kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm.m_clamp_k_eq_gs128
[       OK ] test_kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm.m_clamp_k_eq_gs128 (5 ms)
[----------] 6 tests from test_kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm (121 ms total)
```

* [Experimental] Kleidi: rename arg name for packing functions
* [Experimental] Change kernel cmake_out dir to avoid conflict
1 parent 72fb597 commit ca52cdc

10 files changed: +557 −17 lines

.gitignore

Lines changed: 4 additions & 1 deletion
@@ -371,4 +371,7 @@ venv/
sweep/

# Model checkpoints
-checkpoints/
+checkpoints/
+
+# Experimental
+torchao/experimental/cmake-out

torchao/experimental/CMakeLists.txt

Lines changed: 4 additions & 0 deletions
@@ -23,6 +23,10 @@ if(NOT TORCHAO_INCLUDE_DIRS)
endif()

option(TORCHAO_BUILD_KLEIDIAI "Download, build, and link against Arm KleidiAI library (arm64 only)" OFF)
+if(TORCHAO_BUILD_KLEIDIAI)
+  message(STATUS "Building with Arm KleidiAI library")
+  add_compile_definitions(TORCHAO_ENABLE_KLEIDI=1)
+endif()
include(CMakePrintHelpers)

add_compile_options("-Wall" "-Werror" "-Wno-deprecated")
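For reference, the TORCHAO_ENABLE_KLEIDI definition emitted here (when the build is configured with -DTORCHAO_BUILD_KLEIDIAI=ON) is what C++ sources can key off. Below is a minimal sketch of that pattern; the helper function and its fallback branch are hypothetical, and only the guarded include refers to a header that exists in this tree.

```cpp
// Sketch only: consuming the TORCHAO_ENABLE_KLEIDI compile definition added
// above. The helper function and its fallback are hypothetical illustrations,
// not code from this commit.
#if defined(TORCHAO_ENABLE_KLEIDI)
#include <torchao/experimental/kernels/cpu/aarch64/kleidi/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod.h>
#endif

inline bool kleidi_kernels_compiled_in() {
#if defined(TORCHAO_ENABLE_KLEIDI)
  return true;   // KleidiAI-backed wrappers are available.
#else
  return false;  // Build fell back to the plain aarch64 kernels.
#endif
}
```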

torchao/experimental/kernels/cpu/aarch64/CMakeLists.txt

Lines changed: 2 additions & 3 deletions
@@ -4,8 +4,7 @@
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

-
-if (CMAKE_SYSTEM_PROCESSOR STREQUAL "arm64")
+if ((CMAKE_SYSTEM_PROCESSOR STREQUAL "arm64") OR (CMAKE_SYSTEM_PROCESSOR STREQUAL "aarch64"))
  add_library(
    torchao_kernels_aarch64
    ${CMAKE_CURRENT_SOURCE_DIR}/reduction/find_min_and_max.cpp
@@ -25,7 +24,7 @@ if (CMAKE_SYSTEM_PROCESSOR STREQUAL "arm64")

    # Temporarily exposing this to the parent scope until we wire
    # this up properly from the top level
-    set(TORCHAO_ENABLE_KLEIDI ON PARENT_SCOPE)
+    set(TORCHAO_BUILD_KLEIDI ON PARENT_SCOPE)
    target_link_libraries(torchao_kernels_aarch64 PUBLIC kleidiai)
  endif()
endif()

torchao/experimental/kernels/cpu/aarch64/kleidi/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod.h

Lines changed: 4 additions & 4 deletions
@@ -47,14 +47,14 @@ size_t activation_data_size(int m, int k, int group_size) {
}

void prepare_activation_data(
-    void* activation_data,
+    void* prepared_activation_data,
    int m,
    int k,
    int group_size,
    const float* activations) {
  (void)group_size; // unused
  kai_matmul_clamp_f32_qai8dxp_qsi4c32p::prepare_activation_data(
-      get_ukernel(), activation_data, m, k, activations);
+      get_ukernel(), prepared_activation_data, m, k, activations);
}

size_t weight_data_size(int n, int k, int group_size) {
@@ -63,7 +63,7 @@ size_t weight_data_size(int n, int k, int group_size) {
}

void prepare_weight_data(
-    void* weight_data,
+    void* prepared_weight_data,
    int n,
    int k,
    int group_size,
@@ -73,7 +73,7 @@ void prepare_weight_data(
    const float* bias) {
  kai_matmul_clamp_f32_qai8dxp_qsi4c32p::prepare_weight_data(
      get_ukernel(),
-      weight_data,
+      prepared_weight_data,
      n,
      k,
      group_size,

torchao/experimental/kernels/cpu/aarch64/kleidi/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod.h

Lines changed: 4 additions & 4 deletions
@@ -45,15 +45,15 @@ size_t activation_data_size(int m, int k, int group_size) {
}

void prepare_activation_data(
-    void* activation_data,
+    void* prepared_activation_data,
    int m,
    int k,
    int group_size,
    const float* activations) {
  (void) group_size; // unused
  kai_matmul_clamp_f32_qai8dxp_qsi4c32p::prepare_activation_data(
      get_ukernel(),
-      activation_data,
+      prepared_activation_data,
      m,
      k,
      activations);
@@ -64,7 +64,7 @@ size_t weight_data_size(int n, int k, int group_size) {
}

void prepare_weight_data(
-    void* weight_data,
+    void* prepared_weight_data,
    int n,
    int k,
    int group_size,
@@ -74,7 +74,7 @@ void prepare_weight_data(
    const float* bias) {
  kai_matmul_clamp_f32_qai8dxp_qsi4c32p::prepare_weight_data(
      get_ukernel(),
-      weight_data,
+      prepared_weight_data,
      n,
      k,
      group_size,

torchao/experimental/kernels/cpu/aarch64/kleidi/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm.h

Lines changed: 120 additions & 0 deletions
@@ -0,0 +1,120 @@
+// Copyright (c) Meta Platforms, Inc. and affiliates.
+// All rights reserved.
+//
+// This source code is licensed under the license found in the
+// LICENSE file in the root directory of this source tree.
+
+#pragma once
+#include <kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm.h>
+#include <torchao/experimental/kernels/cpu/aarch64/kleidi/kai_matmul_clamp_f32_qai8dxp_qsi4c32p.h>
+
+namespace torchao::kernels::cpu::aarch64::kleidi {
+namespace kai_matmul_clamp_f32_qai8dxp_qsi4c32p {
+namespace neon_i8mm_8x4x32 {
+
+const Ukernel get_ukernel() {
+  return Ukernel{
+      .get_m_step =
+          kai_get_m_step_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm,
+      .get_n_step =
+          kai_get_n_step_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm,
+      .get_mr =
+          kai_get_mr_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm,
+      .get_nr =
+          kai_get_nr_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm,
+      .get_kr =
+          kai_get_kr_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm,
+      .get_sr =
+          kai_get_sr_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm,
+      .get_lhs_packed_offset =
+          kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm,
+      .get_rhs_packed_offset =
+          kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm,
+      .get_dst_offset =
+          kai_get_dst_offset_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm,
+      .get_dst_size =
+          kai_get_dst_size_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm,
+      .run_matmul =
+          kai_run_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm};
+}
+
+size_t activation_data_size(int m, int k, int group_size) {
+  (void)group_size; // unused
+  return kai_matmul_clamp_f32_qai8dxp_qsi4c32p::activation_data_size(
+      get_ukernel(), m, k);
+}
+
+void prepare_activation_data(
+    void* prepared_activation_data,
+    int m,
+    int k,
+    int group_size,
+    const float* activations) {
+  (void)group_size; // unused
+  kai_matmul_clamp_f32_qai8dxp_qsi4c32p::prepare_activation_data(
+      get_ukernel(), prepared_activation_data, m, k, activations);
+}
+
+size_t weight_data_size(int n, int k, int group_size) {
+  return kai_matmul_clamp_f32_qai8dxp_qsi4c32p::weight_data_size(
+      get_ukernel(), n, k, group_size);
+}
+
+void prepare_weight_data(
+    void* prepared_weight_data,
+    int n,
+    int k,
+    int group_size,
+    const int8_t* weight_qvals,
+    const float* weight_scales,
+    const int8_t* weight_zeros,
+    const float* bias) {
+  kai_matmul_clamp_f32_qai8dxp_qsi4c32p::prepare_weight_data(
+      get_ukernel(),
+      prepared_weight_data,
+      n,
+      k,
+      group_size,
+      weight_qvals,
+      weight_scales,
+      weight_zeros,
+      bias);
+}
+
+void kernel(
+    float32_t* output,
+    int output_m_stride,
+    int m,
+    int n,
+    int k,
+    int group_size,
+    const void* weight_data,
+    const void* activation_data,
+    float clamp_min,
+    float clamp_max) {
+  if (clamp_min == 0 && clamp_max == 0) {
+    clamp_min = std::numeric_limits<float>::lowest();
+    clamp_max = std::numeric_limits<float>::max();
+  }
+
+  auto ukernel = get_ukernel();
+  ukernel.run_matmul(
+      m,
+      n,
+      k,
+      group_size,
+      activation_data,
+      weight_data,
+      output,
+      /*dst_stride_row=*/n * sizeof(float),
+      /*dst_stride_col=*/sizeof(float),
+      clamp_min,
+      clamp_max);
+}
+
+size_t get_preferred_alignement() {
+  return 16;
+}
+} // namespace neon_i8mm_8x4x32
+} // namespace kai_matmul_clamp_f32_qai8dxp_qsi4c32p
+} // namespace torchao::kernels::cpu::aarch64::kleidi
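
This wrapper follows the same prepare/run flow as the existing dotprod variants: query packed-buffer sizes, pack activations and weights, then call kernel. A hypothetical end-to-end caller is sketched below; the function, shapes, and buffer handling are illustrative (not part of this commit), and the torchao header path is assumed from the naming pattern of the other Kleidi wrappers.

```cpp
// Sketch only: driving the neon_i8mm_8x4x32 wrapper added in this commit.
// The surrounding function and buffer management are hypothetical.
#include <cstdint>
#include <vector>

#include <torchao/experimental/kernels/cpu/aarch64/kleidi/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm.h>

namespace uk = torchao::kernels::cpu::aarch64::kleidi::
    kai_matmul_clamp_f32_qai8dxp_qsi4c32p::neon_i8mm_8x4x32;

void run_i8mm_gemm_example(
    int m, int n, int k, int group_size,
    const float* activations,      // m x k, row-major
    const int8_t* weight_qvals,    // n x k quantized weight values
    const float* weight_scales,    // per-group scales
    const int8_t* weight_zeros,    // per-group zero points
    const float* bias,             // length n
    float* output) {               // m x n, row-major
  // 1) Query packed-buffer sizes; a production caller might also respect
  //    get_preferred_alignement() when allocating.
  std::vector<char> packed_activations(
      uk::activation_data_size(m, k, group_size));
  std::vector<char> packed_weights(uk::weight_data_size(n, k, group_size));

  // 2) Pack activations and weights into the ukernel's layout.
  uk::prepare_activation_data(
      packed_activations.data(), m, k, group_size, activations);
  uk::prepare_weight_data(
      packed_weights.data(), n, k, group_size,
      weight_qvals, weight_scales, weight_zeros, bias);

  // 3) Run the matmul. Passing clamp_min == clamp_max == 0 disables clamping:
  //    the wrapper widens the range to the float limits before dispatching.
  uk::kernel(
      output, /*output_m_stride=*/n, m, n, k, group_size,
      packed_weights.data(), packed_activations.data(),
      /*clamp_min=*/0.0f, /*clamp_max=*/0.0f);
}
```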

torchao/experimental/kernels/cpu/aarch64/kleidi/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm.h

Lines changed: 122 additions & 0 deletions
@@ -0,0 +1,122 @@
+// Copyright (c) Meta Platforms, Inc. and affiliates.
+// All rights reserved.
+//
+// This source code is licensed under the license found in the
+// LICENSE file in the root directory of this source tree.
+
+#pragma once
+#include <kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm.h>
+
+#include <torchao/experimental/kernels/cpu/aarch64/kleidi/kai_matmul_clamp_f32_qai8dxp_qsi4c32p.h>
+
+namespace torchao::kernels::cpu::aarch64::kleidi {
+namespace kai_matmul_clamp_f32_qai8dxp_qsi4c32p {
+namespace neon_i8mm_4x8x32 {
+
+const Ukernel get_ukernel() {
+  return Ukernel{
+      .get_m_step =
+          kai_get_m_step_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm,
+      .get_n_step =
+          kai_get_n_step_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm,
+      .get_mr =
+          kai_get_mr_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm,
+      .get_nr =
+          kai_get_nr_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm,
+      .get_kr =
+          kai_get_kr_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm,
+      .get_sr =
+          kai_get_sr_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm,
+      .get_lhs_packed_offset =
+          kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm,
+      .get_rhs_packed_offset =
+          kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm,
+      .get_dst_offset =
+          kai_get_dst_offset_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm,
+      .get_dst_size =
+          kai_get_dst_size_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm,
+      .run_matmul =
+          kai_run_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm};
+}
+
+size_t activation_data_size(int m, int k, int group_size) {
+  (void)group_size; // unused
+  return kai_matmul_clamp_f32_qai8dxp_qsi4c32p::activation_data_size(
+      get_ukernel(), m, k);
+}
+
+void prepare_activation_data(
+    void* prepared_activation_data,
+    int m,
+    int k,
+    int group_size,
+    const float* activations) {
+  (void)group_size; // unused
+  kai_matmul_clamp_f32_qai8dxp_qsi4c32p::prepare_activation_data(
+      get_ukernel(), prepared_activation_data, m, k, activations);
+}
+
+size_t weight_data_size(int n, int k, int group_size) {
+  return kai_matmul_clamp_f32_qai8dxp_qsi4c32p::weight_data_size(
+      get_ukernel(), n, k, group_size);
+}
+
+void prepare_weight_data(
+    void* prepared_weight_data,
+    int n,
+    int k,
+    int group_size,
+    const int8_t* weight_qvals,
+    const float* weight_scales,
+    const int8_t* weight_zeros,
+    const float* bias) {
+  kai_matmul_clamp_f32_qai8dxp_qsi4c32p::prepare_weight_data(
+      get_ukernel(),
+      prepared_weight_data,
+      n,
+      k,
+      group_size,
+      weight_qvals,
+      weight_scales,
+      weight_zeros,
+      bias);
+}
+
+void kernel(
+    float32_t* output,
+    int output_m_stride,
+    int m,
+    int n,
+    int k,
+    int group_size,
+    const void* weight_data,
+    const void* activation_data,
+    float clamp_min,
+    float clamp_max) {
+  if (clamp_min == 0 && clamp_max == 0) {
+    clamp_min = std::numeric_limits<float>::lowest();
+    clamp_max = std::numeric_limits<float>::max();
+  }
+
+  auto ukernel = get_ukernel();
+  ukernel.run_matmul(
+      m,
+      n,
+      k,
+      group_size,
+      activation_data,
+      weight_data,
+      output,
+      /*dst_stride_row=*/n * sizeof(float),
+      /*dst_stride_col=*/sizeof(float),
+      clamp_min,
+      clamp_max);
+}
+
+size_t get_preferred_alignement() {
+  return 16;
+}
+
+} // namespace neon_i8mm_4x8x32
+} // namespace kai_matmul_clamp_f32_qai8dxp_qsi4c32p
+} // namespace torchao::kernels::cpu::aarch64::kleidi
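
This second header mirrors the 8x4x32 one; only the underlying KleidiAI ukernel (4x8x32 tiling) and the namespace differ, so both variants expose identical free functions. Switching the earlier usage sketch to this variant is just a different alias, as in the hypothetical snippet below (the header path again follows the assumed naming pattern).

```cpp
// Sketch only: pointing the earlier usage example at the 4x8x32 i8mm variant.
// Every call site stays the same; only the namespace alias changes.
#include <torchao/experimental/kernels/cpu/aarch64/kleidi/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm.h>

namespace uk = torchao::kernels::cpu::aarch64::kleidi::
    kai_matmul_clamp_f32_qai8dxp_qsi4c32p::neon_i8mm_4x8x32;

// e.g. size_t packed_weight_bytes = uk::weight_data_size(n, k, group_size);
```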

torchao/experimental/kernels/cpu/aarch64/tests/CMakeLists.txt

Lines changed: 23 additions & 2 deletions
@@ -15,6 +15,11 @@ FetchContent_Declare(
)
FetchContent_MakeAvailable(googletest)

+if (ANDROID_ABI)
+  # We are cross compiling, delay test discovery till runtime
+  set(CMAKE_GTEST_DISCOVER_TESTS_DISCOVERY_MODE PRE_TEST)
+endif()
+
add_compile_options("-Wall" "-Werror")

include(CMakePrintHelpers)
@@ -35,13 +40,29 @@ endif()

add_subdirectory(${TORCHAO_LIBRARIES}/torchao/experimental/kernels/cpu/aarch64 ${CMAKE_CURRENT_BINARY_DIR}/torchao_kernels_aarch64)

-# The TORCHAO_ENABLE_KLEIDI cmake variable should be set by `torchao_kernels_aarch64"
-if(TORCHAO_ENABLE_KLEIDI)
+# The TORCHAO_BUILD_KLEIDI cmake variable should be set by `torchao_kernels_aarch64"
+if(TORCHAO_BUILD_KLEIDI)
  add_compile_definitions(TORCHAO_ENABLE_KLEIDI)
endif()

+if(TORCHAO_BUILD_ARM_I8MM)
+  add_compile_definitions(TORCHAO_ENABLE_ARM_I8MM)
+endif()
+
enable_testing()

+if (ANDROID_ABI)
+  # Given where we are today this is sufficent. But needs to be revisited.
+  # This is also needed for native builds, but keeping it only for cross builds
+  # for now given the hacky nature.
+  file(GLOB DOTPROD_SRC_FILES test*.cpp)
+  message(SRC_FILES: ${DOTPROD_SRC_FILES})
+  set_property(SOURCE
+    ${DOTPROD_SRC_FILES}
+    APPEND_STRING PROPERTY
+    COMPILE_FLAGS " -march=armv8.2-a+dotprod ")
+endif()
+
add_executable(test_quantization test_quantization.cpp)
target_link_libraries(
  test_quantization
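
Because the tests now also key off TORCHAO_BUILD_ARM_I8MM (surfaced to the compiler as TORCHAO_ENABLE_ARM_I8MM), i8mm-only cases can be compiled out on targets without the extension. The GoogleTest snippet below is a hypothetical illustration of that gating; the actual test sources added by this commit are not shown on this page, and the test body here is only a placeholder.

```cpp
// Sketch only: gating an i8mm-specific test on the compile definitions wired
// up in this CMakeLists.txt. The test name and assertions are placeholders.
#include <gtest/gtest.h>

#if defined(TORCHAO_ENABLE_KLEIDI) && defined(TORCHAO_ENABLE_ARM_I8MM)
#include <torchao/experimental/kernels/cpu/aarch64/kleidi/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm.h>

TEST(test_kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm, smoke) {
  namespace uk = torchao::kernels::cpu::aarch64::kleidi::
      kai_matmul_clamp_f32_qai8dxp_qsi4c32p::neon_i8mm_8x4x32;
  // Packed-buffer sizes should be non-zero for a small problem; the real
  // tests in this commit check numerical correctness end to end.
  EXPECT_GT(uk::activation_data_size(/*m=*/8, /*k=*/64, /*group_size=*/32), 0u);
  EXPECT_GT(uk::weight_data_size(/*n=*/4, /*k=*/64, /*group_size=*/32), 0u);
}
#endif // TORCHAO_ENABLE_KLEIDI && TORCHAO_ENABLE_ARM_I8MM
```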
