From 458fa7d30dc22f2be7b3d04a44b7f32c4d8a703d Mon Sep 17 00:00:00 2001 From: NihalHarish Date: Wed, 28 Oct 2020 04:07:39 -0700 Subject: [PATCH 1/2] Revert "pytorch tmp (#382)" This reverts commit ef3b6b240357f593ad32aa4f781b03e9b0cc4be9. --- tests/zero_code_change/pt_utils.py | 16 ++-------------- 1 file changed, 2 insertions(+), 14 deletions(-) diff --git a/tests/zero_code_change/pt_utils.py b/tests/zero_code_change/pt_utils.py index 9bb1c073e..3d6cb78de 100644 --- a/tests/zero_code_change/pt_utils.py +++ b/tests/zero_code_change/pt_utils.py @@ -7,7 +7,6 @@ import torch.nn.functional as F import torchvision import torchvision.transforms as transforms -from packaging import version def get_dataloaders() -> Tuple[torch.utils.data.DataLoader, torch.utils.data.DataLoader]: @@ -15,26 +14,15 @@ def get_dataloaders() -> Tuple[torch.utils.data.DataLoader, torch.utils.data.Dat [transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))] ) - # Temporary Change to allow the test to run with pytorch 1.7 RC3 - # Smdebug breaks when num_workers>0 for Pytorch 1.7.0 - if version.parse(torch.__version__) >= version.parse("1.7.0"): - num_workers = 0 - else: - num_workers = 2 - trainset = torchvision.datasets.CIFAR10( root="./data", train=True, download=True, transform=transform ) - trainloader = torch.utils.data.DataLoader( - trainset, batch_size=4, shuffle=True, num_workers=num_workers - ) + trainloader = torch.utils.data.DataLoader(trainset, batch_size=4, shuffle=True, num_workers=2) testset = torchvision.datasets.CIFAR10( root="./data", train=False, download=True, transform=transform ) - testloader = torch.utils.data.DataLoader( - testset, batch_size=4, shuffle=False, num_workers=num_workers - ) + testloader = torch.utils.data.DataLoader(testset, batch_size=4, shuffle=False, num_workers=2) classes = ("plane", "car", "bird", "cat", "deer", "dog", "frog", "horse", "ship", "truck") return trainloader, testloader From 30b3a2c6c3a322554603a2605c9d9a54476b4307 Mon Sep 17 00:00:00 2001 From: NihalHarish Date: Wed, 28 Oct 2020 04:07:50 -0700 Subject: [PATCH 2/2] Revert "disable pytorch (#386)" This reverts commit 311a6f451eafd56e121e5eb784fcf9fc3b269db4. --- tests/zero_code_change/test_pytorch_integration.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/tests/zero_code_change/test_pytorch_integration.py b/tests/zero_code_change/test_pytorch_integration.py index 21e7759f8..eb6d06536 100644 --- a/tests/zero_code_change/test_pytorch_integration.py +++ b/tests/zero_code_change/test_pytorch_integration.py @@ -12,7 +12,6 @@ # Third Party import pytest -import torch import torch.nn as nn import torch.nn.functional as F import torch.optim as optim @@ -23,10 +22,6 @@ from smdebug.core.utils import SagemakerSimulator, ScriptSimulator -@pytest.mark.skipif( - torch.__version__ == "1.7.0", - reason="Disabling the test temporarily until we root cause the version incompatibility", -) @pytest.mark.parametrize("script_mode", [False]) @pytest.mark.parametrize("use_loss_module", [True, False]) def test_pytorch(script_mode, use_loss_module):