Models and pre-trained weights
##############################

The ``torchvision.models`` subpackage contains definitions of models for addressing
different tasks, including: image classification, pixelwise semantic
segmentation, object detection, instance segmentation, person
keypoint detection, video classification, and optical flow.

General information on pre-trained weights
==========================================

TorchVision offers pre-trained weights for every provided architecture, using
the PyTorch :mod:`torch.hub`. Instantiating a pre-trained model will download its
weights to a cache directory. This directory can be set using the ``TORCH_HOME``
environment variable. See :func:`torch.hub.load_state_dict_from_url` for details.

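For example, to redirect the cache before the first download is triggered (the
path below is only a placeholder, not a TorchVision convention):

.. code:: python

    import os

    # Must be set before the first weights download is triggered
    os.environ["TORCH_HOME"] = "/path/to/my/cache"

    from torchvision.models import resnet50, ResNet50_Weights

    # The checkpoint lands under $TORCH_HOME/hub/checkpoints
    model = resnet50(weights=ResNet50_Weights.DEFAULT)
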
.. note::

    The pre-trained models provided in this library may have their own licenses or
    terms and conditions derived from the dataset used for training. It is your
    responsibility to determine whether you have permission to use the models for
    your use case.

.. note::
    Backward compatibility is guaranteed for loading a serialized
    ``state_dict`` into a model created with an older version of PyTorch.
    In contrast, loading entire saved models or serialized
    ``ScriptModules`` (serialized using older versions of PyTorch)
    may not preserve the historic behaviour. Refer to the following
    `documentation
    <https://pytorch.org/docs/stable/notes/serialization.html#id6>`_.

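In practice, the ``state_dict`` route looks like the following generic sketch (not
tied to any particular architecture):

.. code:: python

    import torch
    from torchvision.models import resnet50, ResNet50_Weights

    model = resnet50(weights=ResNet50_Weights.DEFAULT)
    torch.save(model.state_dict(), "resnet50.pth")

    # Later, possibly under a newer PyTorch version:
    model = resnet50(weights=None)
    model.load_state_dict(torch.load("resnet50.pth"))
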

Initializing pre-trained models
-------------------------------

As of v0.13, TorchVision offers a new `Multi-weight support API
<https://pytorch.org/blog/introducing-torchvision-new-multi-weight-support-api/>`_
for loading different weights to the existing model builder methods:

.. code:: python

    from torchvision.models import resnet50, ResNet50_Weights

    # Old weights with accuracy 76.130%
    resnet50(weights=ResNet50_Weights.IMAGENET1K_V1)

    # New weights with accuracy 80.858%
    resnet50(weights=ResNet50_Weights.IMAGENET1K_V2)

    # Best available weights (currently alias for IMAGENET1K_V2)
    resnet50(weights=ResNet50_Weights.DEFAULT)

    # Strings are also supported
    resnet50(weights="IMAGENET1K_V2")

    # No weights - random initialization
    resnet50(weights=None)


Migrating to the new API is very straightforward. The following method calls between the two APIs are all equivalent:

.. code:: python

    from torchvision.models import resnet50, ResNet50_Weights

    # Using pretrained weights:
    resnet50(weights=ResNet50_Weights.IMAGENET1K_V1)
    resnet50(weights="IMAGENET1K_V1")
    resnet50(pretrained=True)  # deprecated
    resnet50(True)  # deprecated

    # Using no weights:
    resnet50(weights=None)
    resnet50()
    resnet50(pretrained=False)  # deprecated
    resnet50(False)  # deprecated

Note that the ``pretrained`` parameter is now deprecated; using it will emit warnings and it will be removed in v0.15.

Using the pre-trained models
----------------------------

Before using the pre-trained models, one must preprocess the image
(resize with the right resolution/interpolation, apply inference transforms,
rescale the values, etc.). There is no standard way to do this as it depends on
how a given model was trained. It can vary across model families, variants or
even weight versions. Using the correct preprocessing method is critical and
failing to do so may lead to decreased accuracy or incorrect outputs.

All the necessary information for the inference transforms of each pre-trained
model is provided in its weights documentation. To simplify inference, TorchVision
bundles the necessary preprocessing transforms into each model weight. These are
accessible via the ``weight.transforms`` attribute:

.. code:: python

    # Initialize the Weight Transforms
    weights = ResNet50_Weights.DEFAULT
    preprocess = weights.transforms()

    # Apply it to the input image
    img_transformed = preprocess(img)


Some models use modules which have different training and evaluation
behavior, such as batch normalization. To switch between these modes, use
``model.train()`` or ``model.eval()`` as appropriate. See
:meth:`~torch.nn.Module.train` or :meth:`~torch.nn.Module.eval` for details.

.. code:: python

    # Initialize model
    weights = ResNet50_Weights.DEFAULT
    model = resnet50(weights=weights)

    # Set model to eval mode
    model.eval()

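During inference you will typically also want to disable gradient tracking. This is
standard PyTorch practice rather than anything TorchVision-specific; in the sketch
below, ``img`` is an image tensor as in the earlier snippets:

.. code:: python

    import torch

    with torch.no_grad():
        batch = weights.transforms()(img).unsqueeze(0)
        prediction = model(batch)
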

Classification
==============

.. currentmodule:: torchvision.models

Here is an example of how to use the pre-trained image classification models:

.. code:: python

    from torchvision.io import read_image
    from torchvision.models import resnet50, ResNet50_Weights

    img = read_image("test/assets/encode_jpeg/grace_hopper_517x606.jpg")

    # Step 1: Initialize model with the best available weights
    weights = ResNet50_Weights.DEFAULT
    model = resnet50(weights=weights)
    model.eval()

    # Step 2: Initialize the inference transforms
    preprocess = weights.transforms()

    # Step 3: Apply inference preprocessing transforms
    batch = preprocess(img).unsqueeze(0)

    # Step 4: Use the model and print the predicted category
    prediction = model(batch).squeeze(0).softmax(0)
    class_id = prediction.argmax().item()
    score = prediction[class_id].item()
    category_name = weights.meta["categories"][class_id]
    print(f"{category_name}: {100 * score:.1f}%")

The classes of the pre-trained model outputs can be found at ``weights.meta["categories"]``.

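For instance, a small sketch mapping a prediction back to a human-readable label
(``class_id`` as computed in the snippet above):

.. code:: python

    weights = ResNet50_Weights.DEFAULT
    categories = weights.meta["categories"]
    print(len(categories))       # 1000 classes for the ImageNet weights
    print(categories[class_id])  # label of the predicted class
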
Table of all available classification weights
---------------------------------------------

Accuracies are reported on ImageNet-1K using single crops:

.. include:: generated/classification_table.rst

Quantized models
----------------

.. currentmodule:: torchvision.models.quantization

The following architectures provide support for INT8 quantized models, with or without
pre-trained weights:

.. toctree::
   :maxdepth: 1

   models/googlenet_quant
   models/inception_quant
   models/mobilenetv2_quant
   models/mobilenetv3_quant
   models/resnet_quant
   models/resnext_quant
   models/shufflenetv2_quant

Here is an example of how to use the pre-trained quantized image classification models:

.. code:: python

    from torchvision.io import read_image
    from torchvision.models.quantization import resnet50, ResNet50_QuantizedWeights

    img = read_image("test/assets/encode_jpeg/grace_hopper_517x606.jpg")

    # Step 1: Initialize model with the best available weights
    weights = ResNet50_QuantizedWeights.DEFAULT
    model = resnet50(weights=weights, quantize=True)
    model.eval()

    # Step 2: Initialize the inference transforms
    preprocess = weights.transforms()

    # Step 3: Apply inference preprocessing transforms
    batch = preprocess(img).unsqueeze(0)

    # Step 4: Use the model and print the predicted category
    prediction = model(batch).squeeze(0).softmax(0)
    class_id = prediction.argmax().item()
    score = prediction[class_id].item()
    category_name = weights.meta["categories"][class_id]
    print(f"{category_name}: {100 * score}%")

The classes of the pre-trained model outputs can be found at ``weights.meta["categories"]``.


Table of all available quantized classification weights
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Accuracies are reported on ImageNet-1K using single crops:

.. include:: generated/classification_quant_table.rst

Semantic Segmentation
=====================

.. currentmodule:: torchvision.models.segmentation

The following semantic segmentation models are available, with or without
pre-trained weights:

.. toctree::
   :maxdepth: 1

   models/deeplabv3
   models/fcn
   models/lraspp

Here is an example of how to use the pre-trained semantic segmentation models:

.. code:: python

    from torchvision.io.image import read_image
    from torchvision.models.segmentation import fcn_resnet50, FCN_ResNet50_Weights
    from torchvision.transforms.functional import to_pil_image

    img = read_image("gallery/assets/dog1.jpg")

    # Step 1: Initialize model with the best available weights
    weights = FCN_ResNet50_Weights.DEFAULT
    model = fcn_resnet50(weights=weights)
    model.eval()

    # Step 2: Initialize the inference transforms
    preprocess = weights.transforms()

    # Step 3: Apply inference preprocessing transforms
    batch = preprocess(img).unsqueeze(0)

    # Step 4: Use the model and visualize the prediction
    prediction = model(batch)["out"]
    normalized_masks = prediction.softmax(dim=1)
    class_to_idx = {cls: idx for (idx, cls) in enumerate(weights.meta["categories"])}
    mask = normalized_masks[0, class_to_idx["dog"]]
    to_pil_image(mask).show()

The classes of the pre-trained model outputs can be found at ``weights.meta["categories"]``.
The output format of the models is illustrated in :ref:`semantic_seg_output`.


Table of all available semantic segmentation weights
----------------------------------------------------

All models are evaluated on a subset of COCO val2017, on the 20 categories that are
present in the Pascal VOC dataset:

.. include:: generated/segmentation_table.rst


Object Detection, Instance Segmentation and Person Keypoint Detection
======================================================================

The pre-trained models for detection, instance segmentation and
keypoint detection are initialized with the classification models
in torchvision. The models expect a list of ``Tensor[C, H, W]``.
Check the constructor of the models for more information.

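Because the input is a list rather than a stacked batch, the images do not need to
share a size. A minimal sketch, with random tensors standing in for real images:

.. code:: python

    import torch
    from torchvision.models.detection import fasterrcnn_resnet50_fpn, FasterRCNN_ResNet50_FPN_Weights

    model = fasterrcnn_resnet50_fpn(weights=FasterRCNN_ResNet50_FPN_Weights.DEFAULT)
    model.eval()

    # Two images with different sizes, each a float Tensor[C, H, W] in [0, 1]
    images = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)]
    predictions = model(images)  # one dict per image: "boxes", "labels", "scores"
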

Object Detection
----------------

.. currentmodule:: torchvision.models.detection

Here is an example of how to use the pre-trained object detection models:

.. code:: python

    from torchvision.io.image import read_image
    from torchvision.utils import draw_bounding_boxes
    from torchvision.transforms.functional import to_pil_image
    from torchvision.models.detection import fasterrcnn_resnet50_fpn_v2, FasterRCNN_ResNet50_FPN_V2_Weights

    img = read_image("test/assets/encode_jpeg/grace_hopper_517x606.jpg")

    # Step 1: Initialize model with the best available weights
    weights = FasterRCNN_ResNet50_FPN_V2_Weights.DEFAULT
    model = fasterrcnn_resnet50_fpn_v2(weights=weights, box_score_thresh=0.9)
    model.eval()

    # Step 2: Initialize the inference transforms
    preprocess = weights.transforms()

    # Step 3: Apply inference preprocessing transforms
    batch = [preprocess(img)]

    # Step 4: Use the model and visualize the prediction
    prediction = model(batch)[0]
    labels = [weights.meta["categories"][i] for i in prediction["labels"]]
    box = draw_bounding_boxes(img, boxes=prediction["boxes"],
                              labels=labels,
                              colors="red",
                              width=4, font_size=30)
    im = to_pil_image(box.detach())
    im.show()

The classes of the pre-trained model outputs can be found at ``weights.meta["categories"]``.
For details on how to plot the bounding boxes of the models, you may refer to :ref:`instance_seg_output`.

Table of all available Object detection weights
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Box MAPs are reported on COCO val2017:

.. include:: generated/detection_table.rst

Instance Segmentation
---------------------

The following instance segmentation models are available, with or without pre-trained
weights:

.. toctree::
   :maxdepth: 1

   models/mask_rcnn

|


For details on how to plot the masks of the models, you may refer to :ref:`instance_seg_output`.

Table of all available Instance segmentation weights
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Box and Mask MAPs are reported on COCO val2017:

.. include:: generated/instance_segmentation_table.rst

Keypoint Detection
------------------

.. currentmodule:: torchvision.models.detection

The following person keypoint detection models are available, with or without
pre-trained weights:

.. toctree::
   :maxdepth: 1

   models/keypoint_rcnn

|

The classes of the pre-trained model outputs can be found at ``weights.meta["keypoint_names"]``.
For details on how to plot the bounding boxes of the models, you may refer to :ref:`keypoint_output`.

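For instance, a quick sketch to inspect the joint order (assuming the default
Keypoint R-CNN weights):

.. code:: python

    from torchvision.models.detection import KeypointRCNN_ResNet50_FPN_Weights

    weights = KeypointRCNN_ResNet50_FPN_Weights.DEFAULT
    print(weights.meta["keypoint_names"])  # ('nose', 'left_eye', ...)
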
Table of all available Keypoint detection weights
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Box and Keypoint MAPs are reported on COCO val2017:

.. include:: generated/detection_keypoint_table.rst

Video Classification
====================

.. currentmodule:: torchvision.models.video

Here is an example of how to use the pre-trained video classification models:

.. code:: python

    from torchvision.io.video import read_video
    from torchvision.models.video import r3d_18, R3D_18_Weights

    vid, _, _ = read_video("test/assets/videos/v_SoccerJuggling_g23_c01.avi", output_format="TCHW")
    vid = vid[:32]  # optionally shorten duration

    # Step 1: Initialize model with the best available weights
    weights = R3D_18_Weights.DEFAULT
    model = r3d_18(weights=weights)
    model.eval()

    # Step 2: Initialize the inference transforms
    preprocess = weights.transforms()

    # Step 3: Apply inference preprocessing transforms
    batch = preprocess(vid).unsqueeze(0)

    # Step 4: Use the model and print the predicted category
    prediction = model(batch).squeeze(0).softmax(0)
    label = prediction.argmax().item()
    score = prediction[label].item()
    category_name = weights.meta["categories"][label]
    print(f"{category_name}: {100 * score}%")

The classes of the pre-trained model outputs can be found at ``weights.meta["categories"]``.

Table of all available video classification weights
---------------------------------------------------

Accuracies are reported on Kinetics-400 using single crops for clip length 16:

.. include:: generated/video_table.rst

Using models from Hub
=====================

Most pre-trained models can be accessed directly via PyTorch Hub without having TorchVision installed:

.. code:: python

    import torch

    # Option 1: passing weights param as string
    model = torch.hub.load("pytorch/vision", "resnet50", weights="IMAGENET1K_V2")

    # Option 2: passing weights param as enum
    weights = torch.hub.load("pytorch/vision", "get_weight", weights="ResNet50_Weights.IMAGENET1K_V2")
    model = torch.hub.load("pytorch/vision", "resnet50", weights=weights)

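If you are unsure which entrypoints a repository exposes, ``torch.hub`` can
enumerate them:

.. code:: python

    # Lists the model builders and helpers exposed by the repo's hubconf
    print(torch.hub.list("pytorch/vision"))
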
The only exception to the above are the detection models included in
:mod:`torchvision.models.detection`. These models require TorchVision
to be installed because they depend on custom C++ operators.