1 change: 1 addition & 0 deletions ruff.toml
@@ -10,6 +10,7 @@ include = [
     "torchao/profiler/**/*.py",
     "torchao/testing/**/*.py",
     "torchao/_models/**/*.py",
+    "torchao/kernel/**/*.py",
     "torchao/prototype/low_bit_optim/**.py",
     "torchao/utils.py",
     "torchao/ops.py",
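With torchao/kernel/**/*.py added to the include list, ruff now lints and formats this directory, which is what drives the mechanical changes in the files below. For reference, the standard ruff CLI invocations (paths illustrative):

    ruff check torchao/kernel/    # lint only the newly covered directory
    ruff format torchao/kernel/   # apply formatting fixes like the ones below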
3 changes: 1 addition & 2 deletions torchao/kernel/__init__.py
@@ -1,5 +1,4 @@
-from torchao.kernel.intmm import int_scaled_matmul
-from torchao.kernel.intmm import safe_int_mm
+from torchao.kernel.intmm import int_scaled_matmul, safe_int_mm
 
 __all__ = [
     "safe_int_mm",
3 changes: 1 addition & 2 deletions torchao/kernel/autotuner.py
@@ -1,7 +1,6 @@
 import logging
 import os
 import pathlib
-import pickle
 
 import torch
 import triton
@@ -173,7 +172,7 @@ def wrapped_fn():
     # Run it once and skip if it crashes or is 100x slower
     try:
         time = do_bench_basic(wrapped_fn, 1)
-    except RuntimeError as e:
+    except RuntimeError:
         time = None
     except triton.runtime.OutOfResources:
         time = None
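The dropped `as e` binding was unused, which is the lint fix here; the surrounding pattern benchmarks each candidate config once and treats a crash as "skip this config" rather than aborting the search. A standalone sketch of that pattern (function names illustrative, not the module's API):

def bench_or_skip(bench_fn, candidate):
    # Time one candidate; a RuntimeError (or Triton's OutOfResources in the
    # real code) disqualifies it by returning None instead of raising.
    try:
        return bench_fn(candidate)
    except RuntimeError:
        return None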
26 changes: 18 additions & 8 deletions torchao/kernel/intmm.py
@@ -1,5 +1,5 @@
 import itertools
-import os
+
 import torch
 
 from torchao.utils import TORCH_VERSION_AT_LEAST_2_2, TORCH_VERSION_AT_LEAST_2_6
@@ -21,6 +21,7 @@
 if TORCH_VERSION_AT_LEAST_2_2:
     from torch._dynamo import is_compiling as dynamo_is_compiling
     from torch._higher_order_ops.out_dtype import out_dtype
+
     def safe_int_mm(input: torch.Tensor, mat2: torch.Tensor) -> torch.Tensor:
         """
         Performs a safe integer matrix multiplication, considering different paths for
@@ -40,7 +41,9 @@ def safe_int_mm(input: torch.Tensor, mat2: torch.Tensor) -> torch.Tensor:
         if dynamo_is_compiling() or "FakeTensor" in input.__repr__():
             if input.device.type == "cpu":
                 # Matmul in int32 is slow on CPU and not supported well by Inductor cpp backend
-                return out_dtype(torch.ops.aten.mm.default, torch.int32, input.float(), mat2.float())
+                return out_dtype(
+                    torch.ops.aten.mm.default, torch.int32, input.float(), mat2.float()
+                )
             return out_dtype(torch.ops.aten.mm.default, torch.int32, input, mat2)
 
         # error checking for cublas path
@@ -60,9 +63,9 @@ def safe_int_mm(input: torch.Tensor, mat2: torch.Tensor) -> torch.Tensor:
 
         if device_cpu or bad_dimensions_for_cublas:
             # fallback path
-            return torch.matmul(input.cpu().to(torch.int32), mat2.cpu().to(torch.int32)).to(
-                input.device.type
-            )
+            return torch.matmul(
+                input.cpu().to(torch.int32), mat2.cpu().to(torch.int32)
+            ).to(input.device.type)
 
         # cublas paths
         if not mat2.is_contiguous():  # silently gives incorrect result without this
@@ -78,8 +81,11 @@ def safe_int_mm(input: torch.Tensor, mat2: torch.Tensor) -> torch.Tensor:
         except Exception:
             # fallback path, would run on H100 for float8 dtypes
             # Exception on H100 float8 dtype : "addmm_cuda" not implemented for 'Float8_e4m3fn'
-            return torch.matmul(input.to(torch.float32), mat2.to(torch.float32)).to(torch.int32)
+            return torch.matmul(input.to(torch.float32), mat2.to(torch.float32)).to(
+                torch.int32
+            )
 else:
+
     def safe_int_mm(input: torch.Tensor, mat2: torch.Tensor) -> torch.Tensor:
         """
         Performs a fallback integer matrix multiplication for torch versions before 2.2.
@@ -93,7 +99,9 @@ def safe_int_mm(input: torch.Tensor, mat2: torch.Tensor) -> torch.Tensor:
         """
         # We can improve on this by writing Triton code that works for older versions of Triton
        # that ship with 2.1 or 2.0.
-        return torch.matmul(input.to(torch.float32), mat2.to(torch.float32)).to(torch.int32)
+        return torch.matmul(input.to(torch.float32), mat2.to(torch.float32)).to(
+            torch.int32
+        )
 
 
 def int_matmul(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
@@ -113,7 +121,9 @@ def int_matmul(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
     return safe_int_mm(a, b)
 
 
-def int_scaled_matmul(a: torch.Tensor, b: torch.Tensor, scales1: torch.Tensor) -> torch.Tensor:
+def int_scaled_matmul(
+    a: torch.Tensor, b: torch.Tensor, scales1: torch.Tensor
+) -> torch.Tensor:
     """
     Performs scaled integer matrix multiplication.
 
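For orientation, a minimal usage sketch of safe_int_mm, which this file defines and torchao/kernel/__init__.py exports (shapes and values are illustrative; path selection is internal, per the docstrings above):

import torch

from torchao.kernel import safe_int_mm

# int8 operands, int32 accumulation; safe_int_mm routes to the torch.compile,
# cuBLAS, or float32-fallback path depending on device, dtype, and dimensions.
a = torch.randint(-128, 127, (64, 32), dtype=torch.int8)
b = torch.randint(-128, 127, (32, 16), dtype=torch.int8)
c = safe_int_mm(a, b)
assert c.shape == (64, 16) and c.dtype == torch.int32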
50 changes: 22 additions & 28 deletions torchao/kernel/intmm_triton.py
@@ -1,42 +1,36 @@
 import itertools
-import os
+
 import torch
-
 import triton
 import triton.language as tl
 
 from torchao.kernel.autotuner import get_best_config_fn
 from torchao.utils import TORCH_VERSION_AFTER_2_5
 
 # TORCHINDUCTOR_MAX_AUTOTUNE_GEMM_SEARCH_SPACE=EXHAUSTIVE to enable exhaustive option
-int8_mm_kernel_configs = (
-    sum(
-        [
-            # "BLOCK_M", "BLOCK_N", "BLOCK_K", "num_stages", "num_warps"
-            [
-                (i, j, k, 1, 1),
-                (i, j, k, 1, 2),
-                (i, j, k, 2, 2),
-                (i, j, k, 1, 4),
-                (i, j, k, 2, 4),
-                (i, j, k, 3, 4),
-                (i, j, k, 4, 4),
-                (i, j, k, 1, 8),
-                (i, j, k, 2, 8),
-                (i, j, k, 3, 8),
-                (i, j, k, 4, 8),
-                (i, j, k, 5, 8),
-                (i, j, k, 6, 8),
-                (i, j, k, 7, 8),
-                (i, j, k, 8, 8),
-            ]
-            for (i, j, k) in itertools.product(
-                [32, 64, 128, 256], repeat=3
-            )
-        ],
-        []
-    )
-)
+int8_mm_kernel_configs = sum(
+    [
+        # "BLOCK_M", "BLOCK_N", "BLOCK_K", "num_stages", "num_warps"
+        [
+            (i, j, k, 1, 1),
+            (i, j, k, 1, 2),
+            (i, j, k, 2, 2),
+            (i, j, k, 1, 4),
+            (i, j, k, 2, 4),
+            (i, j, k, 3, 4),
+            (i, j, k, 4, 4),
+            (i, j, k, 1, 8),
+            (i, j, k, 2, 8),
+            (i, j, k, 3, 8),
+            (i, j, k, 4, 8),
+            (i, j, k, 5, 8),
+            (i, j, k, 6, 8),
+            (i, j, k, 7, 8),
+            (i, j, k, 8, 8),
+        ]
+        for (i, j, k) in itertools.product([32, 64, 128, 256], repeat=3)
+    ],
+    [],
+)
 
 if TORCH_VERSION_AFTER_2_5:
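The restructuring above is behavior-preserving: sum(list_of_lists, []) concatenates the inner lists, so the comprehension still pairs every (BLOCK_M, BLOCK_N, BLOCK_K) shape with every (num_stages, num_warps) combination. A self-contained sketch of the arithmetic (variable names illustrative):

import itertools

# 15 (num_stages, num_warps) pairs per block shape, as listed above.
stage_warp_pairs = [
    (1, 1), (1, 2), (2, 2), (1, 4), (2, 4), (3, 4), (4, 4),
    (1, 8), (2, 8), (3, 8), (4, 8), (5, 8), (6, 8), (7, 8), (8, 8),
]
configs = sum(
    [
        [(i, j, k, s, w) for (s, w) in stage_warp_pairs]
        for (i, j, k) in itertools.product([32, 64, 128, 256], repeat=3)
    ],
    [],
)
assert len(configs) == 4**3 * 15  # 64 block shapes x 15 pairs = 960 configs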