Skip to content
Merged
25 changes: 12 additions & 13 deletions beginner_source/introyt/introyt1_tutorial.py
Original file line number Diff line number Diff line change
Expand Up @@ -303,22 +303,21 @@ def num_flat_features(self, x):
# The values passed to the transform are the means (first tuple) and the
# standard deviations (second tuple) of the rgb values of the images in
# the dataset. You can calculate these values yourself by running these
# few lines of code:
# ```
# from torch.utils.data import ConcatDataset
# transform = transforms.Compose([transforms.ToTensor()])
# trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
# few lines of code::
#
# from torch.utils.data import ConcatDataset
# transform = transforms.Compose([transforms.ToTensor()])
# trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
# download=True, transform=transform)
#
# #stack all train images together into a tensor of shape
# #(50000, 3, 32, 32)
# x = torch.stack([sample[0] for sample in ConcatDataset([trainset])])
# # stack all train images together into a tensor of shape
# # (50000, 3, 32, 32)
# x = torch.stack([sample[0] for sample in ConcatDataset([trainset])])
#
# #get the mean of each channel
# mean = torch.mean(x, dim=(0,2,3)) #tensor([0.4914, 0.4822, 0.4465])
# std = torch.std(x, dim=(0,2,3)) #tensor([0.2470, 0.2435, 0.2616])
#
# ```
# # get the mean of each channel
# mean = torch.mean(x, dim=(0,2,3)) # tensor([0.4914, 0.4822, 0.4465])
# std = torch.std(x, dim=(0,2,3)) # tensor([0.2470, 0.2435, 0.2616])
#
#
# There are many more transforms available, including cropping, centering,
# rotation, and reflection.
Expand Down