File tree Expand file tree Collapse file tree 4 files changed +10
-12
lines changed
torchao/prototype/mx_formats Expand file tree Collapse file tree 4 files changed +10
-12
lines changed Original file line number Diff line number Diff line change @@ -9,4 +9,12 @@ set -eux
99
1010echo " This script is run before building torchao binaries"
1111
12+ python -m pip install --upgrade pip
13+ if [ -z " $PYTORCH_VERSION " ]; then
14+ PYTORCH_DEP=" torch"
15+ else
16+ PYTORCH_DEP=" torch==$PYTORCH_VERSION "
17+ fi
18+ pip install $PYTORCH_DEP
19+
1220pip install setuptools wheel twine auditwheel
Original file line number Diff line number Diff line change @@ -110,15 +110,6 @@ def get_extensions():
110110
111111 return ext_modules
112112
113- # Mimic code from torchvision https://github.com/pytorch/vision/blob/143d078b28f00471156a4e562dd3836370acc9ee/setup.py#L58
114- pytorch_dep = "torch"
115- if os .getenv ("PYTORCH_VERSION" ):
116- pytorch_dep += "==" + os .getenv ("PYTORCH_VERSION" )
117-
118- requirements = [
119- pytorch_dep ,
120- ]
121-
122113setup (
123114 name = package_name ,
124115 version = version + version_suffix ,
@@ -128,7 +119,6 @@ def get_extensions():
128119 "torchao.kernel.configs" : ["*.pkl" ],
129120 },
130121 ext_modules = get_extensions () if use_cpp != "0" else None ,
131- install_requires = requirements ,
132122 extras_require = {"dev" : read_requirements ("dev-requirements.txt" )},
133123 description = "Package for applying ao techniques to GPU models" ,
134124 long_description = open ("README.md" ).read (),
Original file line number Diff line number Diff line change 1515# TODO(future): if needed, make the below work on previous PyTorch versions,
1616# just need to hunt down the previous location of `libdevice`. An assert
1717# at the callsite prevents usage of this on unsupported versions.
18- if TORCH_VERSION_AFTER_2_4 :
18+ if TORCH_VERSION_AFTER_2_4 and has_triton () :
1919 from torch ._inductor .runtime .triton_helpers import libdevice
2020
2121from torchao .prototype .mx_formats .constants import (
Original file line number Diff line number Diff line change 1- 0.3.0
1+ 0.3.1
You can’t perform that action at this time.
0 commit comments