diff --git a/torchao/experimental/kernels/mps/codegen/gen_metal_shader_lib.py b/torchao/experimental/kernels/mps/codegen/gen_metal_shader_lib.py index dc6d9267af..56dcc9b730 100644 --- a/torchao/experimental/kernels/mps/codegen/gen_metal_shader_lib.py +++ b/torchao/experimental/kernels/mps/codegen/gen_metal_shader_lib.py @@ -28,7 +28,7 @@ * This file is generated by gen_metal_shader_lib.py */ -#ifdef ATEN +#ifdef USE_ATEN using namespace at::native::mps; #else #include diff --git a/torchao/experimental/kernels/mps/src/lowbit.h b/torchao/experimental/kernels/mps/src/lowbit.h index d10d00c284..d37001350a 100644 --- a/torchao/experimental/kernels/mps/src/lowbit.h +++ b/torchao/experimental/kernels/mps/src/lowbit.h @@ -17,7 +17,7 @@ #include #include -#ifdef ATEN +#ifdef USE_ATEN #include using namespace at::native::mps; inline void finalize_block(MPSStream* mpsStream) {} diff --git a/torchao/experimental/ops/mps/setup.py b/torchao/experimental/ops/mps/setup.py index e9c206cdb9..1205d43d45 100644 --- a/torchao/experimental/ops/mps/setup.py +++ b/torchao/experimental/ops/mps/setup.py @@ -16,7 +16,7 @@ name="torchao_mps_ops", sources=["register.mm"], include_dirs=[os.getenv("TORCHAO_ROOT")], - extra_compile_args=["-DATEN=1"], + extra_compile_args=["-DUSE_ATEN=1"], ), ], cmdclass={"build_ext": BuildExtension},