Skip to content

Commit 79a7f52

Browse files
committed
Refactored files
1 parent 6fd77d5 commit 79a7f52

18 files changed

+1795
-1511
lines changed

test/dtypes/test_affine_quantized.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,7 @@ def test_to_device(self, apply_quant):
9292

9393
@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
9494
def test_register_new_dispatch(self):
95-
from torchao.dtypes.affine_quantized_tensor import (
95+
from torchao.dtypes.affine_quantized_tensor_ops import (
9696
register_aqt_quantized_linear_dispatch,
9797
deregister_aqt_quantized_linear_dispatch,
9898
)

torchao/dtypes/__init__.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from .nf4tensor import NF4Tensor, to_nf4
22
# from ..prototype.dtypes.uint2 import UInt2Tensor, BitnetTensor
3-
from .uint4 import UInt4Tensor
3+
from .uintx import UInt4Tensor
44
from .affine_quantized_tensor import (
55
AffineQuantizedTensor,
66
to_affine_quantized_intx,
@@ -9,15 +9,22 @@
99
to_affine_quantized_fpx,
1010
to_affine_quantized_floatx,
1111
to_affine_quantized_floatx_static,
12+
PlainAQTTensorImpl,
13+
)
14+
from .affine_quantized_tensor_ops import *
15+
from .utils import (
1216
Layout,
1317
PlainLayout,
14-
SemiSparseLayout,
15-
TensorCoreTiledLayout,
18+
)
19+
from .floatx import (
1620
Float8Layout,
1721
Float8AQTTensorImpl,
22+
)
23+
from .uintx import (
24+
SemiSparseLayout,
25+
TensorCoreTiledLayout,
1826
MarlinSparseLayout,
1927
)
20-
2128
__all__ = [
2229
"NF4Tensor",
2330
"to_nf4",

0 commit comments

Comments
 (0)