From 4c6e8489f77438e3452fb464e64623f8be5188c8 Mon Sep 17 00:00:00 2001 From: Huy Do Date: Fri, 10 Nov 2023 16:47:24 -0800 Subject: [PATCH 1/4] Add some more validation checks for torch.linalg.eigh and torch.compile --- check_binary.sh | 6 ++++++ test_example_code/torch_compile_smoke.py | 12 ++++++++++++ 2 files changed, 18 insertions(+) create mode 100644 test_example_code/torch_compile_smoke.py diff --git a/check_binary.sh b/check_binary.sh index 30b44b535..42ee0e997 100755 --- a/check_binary.sh +++ b/check_binary.sh @@ -404,6 +404,12 @@ if [[ "$DESIRED_CUDA" != 'cpu' && "$DESIRED_CUDA" != 'cpu-cxx11-abi' && "$DESIRE echo "Test that linalg works" python -c "import torch;x=torch.rand(3,3,device='cuda');print(torch.linalg.svd(torch.mm(x.t(), x)))" + echo "Test that linalg.eigh works" + python -c "import torch;x=torch.rand(3,3,device='cuda');print(torch.linalg.eigh(x))" + + echo "Checking that basic torch.compile works" + python ${TEST_CODE_DIR}/torch_compile_smoke.py + popd fi # if libtorch fi # if cuda diff --git a/test_example_code/torch_compile_smoke.py b/test_example_code/torch_compile_smoke.py new file mode 100644 index 000000000..7a12a013e --- /dev/null +++ b/test_example_code/torch_compile_smoke.py @@ -0,0 +1,12 @@ +import torch + + +def foo(x: torch.Tensor) -> torch.Tensor: + return torch.sin(x) + torch.cos(x) + + +if __name__ == "__main__": + x = torch.rand(3, 3, device="cuda") + x_eager = foo(x) + x_pt2 = torch.compile(foo)(x) + print(torch.allclose(x_eager, x_pt2)) From 6d02a86e8734eb2130979f1f622117eb7a8c71bd Mon Sep 17 00:00:00 2001 From: Huy Do Date: Fri, 10 Nov 2023 20:42:42 -0800 Subject: [PATCH 2/4] Update test --- check_binary.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/check_binary.sh b/check_binary.sh index 42ee0e997..9e7d03a54 100755 --- a/check_binary.sh +++ b/check_binary.sh @@ -405,7 +405,7 @@ if [[ "$DESIRED_CUDA" != 'cpu' && "$DESIRED_CUDA" != 'cpu-cxx11-abi' && "$DESIRE python -c "import torch;x=torch.rand(3,3,device='cuda');print(torch.linalg.svd(torch.mm(x.t(), x)))" echo "Test that linalg.eigh works" - python -c "import torch;x=torch.rand(3,3,device='cuda');print(torch.linalg.eigh(x))" + python -c "import torch;x=torch.rand(3,3,device='cuda');print(torch.linalg.eigh(torch.mm(x.t(), x)))" echo "Checking that basic torch.compile works" python ${TEST_CODE_DIR}/torch_compile_smoke.py From 729b8457fcbd949e8e16330aae07d36946ed36f0 Mon Sep 17 00:00:00 2001 From: Huy Do Date: Wed, 15 Nov 2023 18:04:37 -0800 Subject: [PATCH 3/4] Also update smoke_test.py --- test/smoke_test/smoke_test.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/test/smoke_test/smoke_test.py b/test/smoke_test/smoke_test.py index 3d1b6af64..9224bb4c3 100644 --- a/test/smoke_test/smoke_test.py +++ b/test/smoke_test/smoke_test.py @@ -193,6 +193,9 @@ def smoke_test_linalg() -> None: A = torch.randn(20, 16, 50, 100, device="cuda").type(dtype) torch.linalg.svd(A) + A = torch.rand(3, 3, device="cuda"); + L, Q = torch.linalg.eigh(torch.mm(A.t(), A)) + def smoke_test_compile() -> None: supported_dtypes = [torch.float16, torch.float32, torch.float64] From cf3ccc8e300da1781fbf75f27c1e1ce8ce669562 Mon Sep 17 00:00:00 2001 From: Huy Do Date: Wed, 15 Nov 2023 18:09:20 -0800 Subject: [PATCH 4/4] Fix lint --- test/smoke_test/smoke_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/smoke_test/smoke_test.py b/test/smoke_test/smoke_test.py index 9224bb4c3..64efc7601 100644 --- a/test/smoke_test/smoke_test.py +++ b/test/smoke_test/smoke_test.py @@ -193,7 +193,7 @@ def smoke_test_linalg() -> None: A = torch.randn(20, 16, 50, 100, device="cuda").type(dtype) torch.linalg.svd(A) - A = torch.rand(3, 3, device="cuda"); + A = torch.rand(3, 3, device="cuda") L, Q = torch.linalg.eigh(torch.mm(A.t(), A))