11from functools import partial
22from typing import Any , List , Optional
33
4- from torchvision .prototype .transforms import ImageNetEval
4+ from torchvision .prototype .transforms import ImageClassificationEval
55from torchvision .transforms .functional import InterpolationMode
66
77from ...models .convnext import ConvNeXt , CNBlockConfig
@@ -56,7 +56,7 @@ def _convnext(
5656class ConvNeXt_Tiny_Weights (WeightsEnum ):
5757 IMAGENET1K_V1 = Weights (
5858 url = "https://download.pytorch.org/models/convnext_tiny-983f1562.pth" ,
59- transforms = partial (ImageNetEval , crop_size = 224 , resize_size = 236 ),
59+ transforms = partial (ImageClassificationEval , crop_size = 224 , resize_size = 236 ),
6060 meta = {
6161 ** _COMMON_META ,
6262 "num_params" : 28589128 ,
@@ -70,7 +70,7 @@ class ConvNeXt_Tiny_Weights(WeightsEnum):
7070class ConvNeXt_Small_Weights (WeightsEnum ):
7171 IMAGENET1K_V1 = Weights (
7272 url = "https://download.pytorch.org/models/convnext_small-0c510722.pth" ,
73- transforms = partial (ImageNetEval , crop_size = 224 , resize_size = 230 ),
73+ transforms = partial (ImageClassificationEval , crop_size = 224 , resize_size = 230 ),
7474 meta = {
7575 ** _COMMON_META ,
7676 "num_params" : 50223688 ,
@@ -84,7 +84,7 @@ class ConvNeXt_Small_Weights(WeightsEnum):
8484class ConvNeXt_Base_Weights (WeightsEnum ):
8585 IMAGENET1K_V1 = Weights (
8686 url = "https://download.pytorch.org/models/convnext_base-6075fbad.pth" ,
87- transforms = partial (ImageNetEval , crop_size = 224 , resize_size = 232 ),
87+ transforms = partial (ImageClassificationEval , crop_size = 224 , resize_size = 232 ),
8888 meta = {
8989 ** _COMMON_META ,
9090 "num_params" : 88591464 ,
@@ -98,7 +98,7 @@ class ConvNeXt_Base_Weights(WeightsEnum):
9898class ConvNeXt_Large_Weights (WeightsEnum ):
9999 IMAGENET1K_V1 = Weights (
100100 url = "https://download.pytorch.org/models/convnext_large-ea097f82.pth" ,
101- transforms = partial (ImageNetEval , crop_size = 224 , resize_size = 232 ),
101+ transforms = partial (ImageClassificationEval , crop_size = 224 , resize_size = 232 ),
102102 meta = {
103103 ** _COMMON_META ,
104104 "num_params" : 197767336 ,
0 commit comments