@@ -304,15 +304,19 @@ def print_size_of_model(model):
 # ----------------------------------
 #
 # As our last major setup step, we define our dataloaders for our training and testing set.
-# The specific dataset we've created for this tutorial contains just 1000 images, one from
+#
+# ImageNet Data
+# ^^^^^^^^^^^^^
+#
+# The specific dataset we've created for this tutorial contains just 1000 images from ImageNet, one from
 # each class (this dataset, at just over 250 MB, is small enough that it can be downloaded
 # relatively easily). The URL for this custom dataset is:
 #
 # .. code::
 #
 #     https://s3.amazonaws.com/pytorch-tutorial-assets/imagenet_1k.zip
 #
-# To download this data locally using Python, then, you could use:
+# To download this data locally using Python, you could use:
 #
 # .. code:: python
 #
@@ -326,11 +330,32 @@ def print_size_of_model(model):
 #     with open(filename, 'wb') as f:
 #         f.write(r.content)
 #
-#
 # For this tutorial to run, we download this data and move it to the right place using
 # `these lines <https://github.com/pytorch/tutorials/blob/master/Makefile#L97-L98>`_
 # from the `Makefile <https://github.com/pytorch/tutorials/blob/master/Makefile>`_.
 #
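+# If you'd rather place the data yourself, a minimal sketch (assuming the archive was
+# saved as ``imagenet_1k.zip`` by the snippet above, and that it unpacks to a top-level
+# ``imagenet_1k/`` folder) would be:
+#
+# .. code:: python
+#
+#     import zipfile
+#
+#     # extract the archive so that data/imagenet_1k/ holds the images
+#     with zipfile.ZipFile('imagenet_1k.zip', 'r') as zf:
+#         zf.extractall('data/')
+#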
+# To run the code in this tutorial using the entire ImageNet dataset, on the other hand, you could download
+# the data using ``torchvision``, as described
+# `here <https://pytorch.org/docs/stable/torchvision/datasets.html#imagenet>`_. For example,
+# to download the training set and apply some standard transformations to it, you could use:
+#
+# .. code:: python
+#
+#     import torchvision
+#     import torchvision.transforms as transforms
+#
+#     imagenet_dataset = torchvision.datasets.ImageNet(
+#         '~/.data/imagenet',
+#         split='train',
+#         download=True,
+#         transform=transforms.Compose([
+#             transforms.RandomResizedCrop(224),
+#             transforms.RandomHorizontalFlip(),
+#             transforms.ToTensor(),
+#             transforms.Normalize(mean=[0.485, 0.456, 0.406],
+#                                  std=[0.229, 0.224, 0.225]),
+#         ]))
+#
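+# To iterate over that dataset in batches, you could then wrap it in a
+# ``torch.utils.data.DataLoader`` (the batch size and worker count below are just
+# illustrative choices):
+#
+# .. code:: python
+#
+#     import torch
+#
+#     imagenet_loader = torch.utils.data.DataLoader(
+#         imagenet_dataset, batch_size=32, shuffle=True, num_workers=4)
+#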
 # With the data downloaded, we show functions below that define dataloaders we'll use to read
 # in this data. These functions mostly come from
 # `here <https://github.com/pytorch/vision/blob/master/references/detection/train.py>`_.
@@ -374,12 +399,12 @@ def prepare_data_loaders(data_path):
     return data_loader, data_loader_test
 
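+######################################################################
+# As a quick usage sketch, you could then build the dataloaders like this
+# (``data_path`` is defined just below):
+#
+# .. code:: python
+#
+#     data_loader, data_loader_test = prepare_data_loaders(data_path)
+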
 ######################################################################
-# Next, we'll load in the pre-trained MobileNetV2 model. Similarly to the data above, the file with the pre-trained
-# weights is stored at ``https://s3.amazonaws.com/pytorch-tutorial-assets/mobilenet_quantization.pth``:
+# Next, we'll load in the pre-trained MobileNetV2 model. The URL to download the pre-trained
+# weights from is provided in ``torchvision``
+# `here <https://github.com/pytorch/vision/blob/master/torchvision/models/mobilenet.py#L9>`_.
 
 data_path = 'data/imagenet_1k'
 saved_model_dir = 'data/'
-float_model_file = 'mobilenet_quantization.pth'
+float_model_file = 'mobilenet_pretrained_float.pth'
 scripted_float_model_file = 'mobilenet_quantization_scripted.pth'
 scripted_quantized_model_file = 'mobilenet_quantization_scripted_quantized.pth'
 
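+######################################################################
+# If you want to fetch the weights file yourself, a minimal sketch (assuming the URL
+# exposed as ``model_urls['mobilenet_v2']`` in ``torchvision``'s ``mobilenet.py``,
+# linked above) would be:
+#
+# .. code:: python
+#
+#     import requests
+#
+#     from torchvision.models.mobilenet import model_urls
+#
+#     # save the float weights where float_model_file expects them
+#     r = requests.get(model_urls['mobilenet_v2'])
+#     with open(saved_model_dir + float_model_file, 'wb') as f:
+#         f.write(r.content)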